Initialisation depot
This commit is contained in:
298
arti-api/auth-service/app.py
Normal file
298
arti-api/auth-service/app.py
Normal file
@@ -0,0 +1,298 @@
|
||||
from fastapi import FastAPI, HTTPException, Request, Response, Depends, Form
|
||||
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
import jwt
|
||||
import bcrypt
|
||||
import ldap3
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import logging
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
"""
|
||||
This is an authentication service using FastAPI that verifies user credentials against Active Directory (AD)
|
||||
and issues JWT tokens for authenticated users. It supports cross-domain authentication and is designed to work
|
||||
with Traefik as a reverse proxy.
|
||||
This will be front end for apps
|
||||
"""
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
JWT_SECRET = os.getenv("JWT_SECRET", "your-super-secret-key-change-this")
|
||||
JWT_ALGORITHM = "HS256"
|
||||
TOKEN_EXPIRE_HOURS = int(os.getenv("TOKEN_EXPIRE_HOURS", "8"))
|
||||
|
||||
# Domain configuration for cross-domain auth
|
||||
ALLOWED_DOMAINS = os.getenv("ALLOWED_DOMAINS", "domain.tld").split(",")
|
||||
AUTH_DOMAIN = os.getenv("AUTH_DOMAIN", "auth.domain.tld")
|
||||
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "https://*.domain.tld").split(",")
|
||||
|
||||
app = FastAPI(title="Authentication Service", description="AD Authentication with JWT tokens")
|
||||
|
||||
# Add CORS middleware for cross-domain authentication
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # In production, specify exact domains
|
||||
allow_credentials=True,
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Domain configuration for cross-domain auth
|
||||
ALLOWED_DOMAINS = os.getenv("ALLOWED_DOMAINS", "domain.tld").split(",")
|
||||
AUTH_DOMAIN = os.getenv("AUTH_DOMAIN", "auth.domain.tld")
|
||||
CORS_ORIGINS = os.getenv("CORS_ORIGINS", "https://*.domain.tld").split(",")
|
||||
|
||||
# Active Directory Configuration
|
||||
AD_SERVER = os.getenv("AD_SERVER", "ldap://your-ad-server.com")
|
||||
AD_BASE_DN = os.getenv("AD_BASE_DN", "DC=yourdomain,DC=com")
|
||||
AD_USER_SEARCH_BASE = os.getenv("AD_USER_SEARCH_BASE", "CN=Users,DC=yourdomain,DC=com")
|
||||
AD_BIND_USER = os.getenv("AD_BIND_USER", "ReadUser") # Service account for LDAP bind
|
||||
AD_BIND_PASSWORD = os.getenv("AD_BIND_PASSWORD", "")
|
||||
|
||||
# Setup templates and static files
|
||||
templates = Jinja2Templates(directory="templates")
|
||||
app.mount("/static", StaticFiles(directory="static"), name="static")
|
||||
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str
|
||||
password: str
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str
|
||||
email: Optional[str] = None
|
||||
groups: list = []
|
||||
exp: datetime
|
||||
|
||||
def verify_ad_credentials(username: str, password: str) -> dict:
|
||||
"""
|
||||
Verify credentials against Active Directory
|
||||
Returns user info if valid, raises exception if invalid
|
||||
"""
|
||||
try:
|
||||
# Connect to AD server
|
||||
server = ldap3.Server(AD_SERVER, get_info=ldap3.ALL)
|
||||
|
||||
# If we have a service account, use it for initial bind
|
||||
if AD_BIND_USER and AD_BIND_PASSWORD:
|
||||
conn = ldap3.Connection(server, AD_BIND_USER, AD_BIND_PASSWORD, auto_bind=True)
|
||||
else:
|
||||
conn = ldap3.Connection(server)
|
||||
|
||||
# Search for the user
|
||||
search_filter = f"(sAMAccountName={username})"
|
||||
conn.search(AD_USER_SEARCH_BASE, search_filter, attributes=['mail', 'memberOf', 'displayName'])
|
||||
|
||||
if not conn.entries:
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
user_dn = conn.entries[0].entry_dn
|
||||
user_info = {
|
||||
'username': username,
|
||||
'email': str(conn.entries[0].mail) if conn.entries[0].mail else '',
|
||||
'display_name': str(conn.entries[0].displayName) if conn.entries[0].displayName else username,
|
||||
'groups': [str(group) for group in conn.entries[0].memberOf] if conn.entries[0].memberOf else []
|
||||
}
|
||||
|
||||
# Now try to bind with user credentials to verify password
|
||||
user_conn = ldap3.Connection(server, user_dn, password)
|
||||
if not user_conn.bind():
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
user_conn.unbind()
|
||||
conn.unbind()
|
||||
|
||||
logger.info(f"Successfully authenticated user: {username}")
|
||||
return user_info
|
||||
|
||||
except ldap3.core.exceptions.LDAPException as e:
|
||||
logger.error(f"LDAP error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Authentication service error")
|
||||
except Exception as e:
|
||||
logger.error(f"Authentication error: {str(e)}")
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
def create_jwt_token(user_info: dict) -> str:
|
||||
"""Create JWT token with user information"""
|
||||
expire = datetime.utcnow() + timedelta(hours=TOKEN_EXPIRE_HOURS)
|
||||
payload = {
|
||||
"username": user_info["username"],
|
||||
"email": user_info["email"],
|
||||
"display_name": user_info["display_name"],
|
||||
"groups": user_info["groups"],
|
||||
"exp": expire,
|
||||
"iat": datetime.utcnow()
|
||||
}
|
||||
return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
||||
|
||||
def verify_jwt_token(token: str) -> dict:
|
||||
"""Verify and decode JWT token"""
|
||||
try:
|
||||
payload = jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
|
||||
return payload
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(status_code=401, detail="Token expired")
|
||||
except jwt.JWTError:
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def login_page(request: Request):
|
||||
"""Serve the login page"""
|
||||
return templates.TemplateResponse("login.html", {"request": request})
|
||||
|
||||
@app.get("/dashboard", response_class=HTMLResponse)
|
||||
async def dashboard(request: Request):
|
||||
"""Serve the dashboard page"""
|
||||
return templates.TemplateResponse("dashboard.html", {"request": request})
|
||||
|
||||
@app.post("/auth/login")
|
||||
async def login(username: str = Form(...), password: str = Form(...)):
|
||||
"""Authenticate user and return JWT token"""
|
||||
try:
|
||||
# Verify credentials against AD
|
||||
user_info = verify_ad_credentials(username, password)
|
||||
|
||||
# Create JWT token
|
||||
token = create_jwt_token(user_info)
|
||||
|
||||
# Create response with token in cookie
|
||||
response = JSONResponse({
|
||||
"success": True,
|
||||
"message": "Login successful",
|
||||
"user": {
|
||||
"username": user_info["username"],
|
||||
"email": user_info["email"],
|
||||
"display_name": user_info["display_name"]
|
||||
}
|
||||
})
|
||||
|
||||
# Set HTTP-only cookie with token (works across subdomains)
|
||||
response.set_cookie(
|
||||
key="auth_token",
|
||||
value=token,
|
||||
domain=f".{ALLOWED_DOMAINS[0]}", # Set for all subdomains
|
||||
httponly=True,
|
||||
secure=True, # Use HTTPS in production
|
||||
samesite="lax",
|
||||
max_age=TOKEN_EXPIRE_HOURS * 3600
|
||||
)
|
||||
|
||||
# Also return token for local storage (optional)
|
||||
response.headers["X-Auth-Token"] = token
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException as e:
|
||||
return JSONResponse(
|
||||
status_code=e.status_code,
|
||||
content={"success": False, "message": e.detail}
|
||||
)
|
||||
|
||||
@app.post("/auth/verify")
|
||||
async def verify_token(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
|
||||
"""Verify token endpoint for Traefik ForwardAuth"""
|
||||
token = None
|
||||
|
||||
# Get original request information from Traefik headers
|
||||
original_host = request.headers.get("X-Forwarded-Host", request.headers.get("Host", ""))
|
||||
original_proto = request.headers.get("X-Forwarded-Proto", "https")
|
||||
original_uri = request.headers.get("X-Forwarded-Uri", "/")
|
||||
original_url = request.headers.get("X-Original-URL", f"{original_proto}://{original_host}{original_uri}")
|
||||
|
||||
# Check Authorization header first
|
||||
if credentials:
|
||||
token = credentials.credentials
|
||||
else:
|
||||
# Check cookie
|
||||
token = request.cookies.get("auth_token")
|
||||
|
||||
if not token:
|
||||
# Redirect to auth service with return URL
|
||||
auth_url = f"https://{AUTH_DOMAIN}/?return_url={original_url}"
|
||||
logger.info(f"No token found, redirecting to: {auth_url}")
|
||||
|
||||
# Return 401 with redirect location for Traefik
|
||||
response = JSONResponse(
|
||||
status_code=401,
|
||||
content={"error": "Authentication required", "auth_url": auth_url}
|
||||
)
|
||||
response.headers["Location"] = auth_url
|
||||
return response
|
||||
|
||||
try:
|
||||
payload = verify_jwt_token(token)
|
||||
|
||||
# Check if the request is from an allowed domain
|
||||
is_allowed_domain = any(domain in original_host for domain in ALLOWED_DOMAINS)
|
||||
if not is_allowed_domain:
|
||||
logger.warning(f"Access denied for domain: {original_host}")
|
||||
raise HTTPException(status_code=403, detail="Domain not allowed")
|
||||
|
||||
# Return user info in headers for Traefik
|
||||
headers = {
|
||||
"X-Auth-User": payload["username"],
|
||||
"X-Auth-Email": payload["email"],
|
||||
"X-Auth-Groups": ",".join(payload["groups"]),
|
||||
"X-Auth-Display-Name": payload["display_name"],
|
||||
"X-Auth-Domain": original_host
|
||||
}
|
||||
|
||||
logger.info(f"Authentication successful for {payload['username']} accessing {original_host}")
|
||||
|
||||
return JSONResponse(
|
||||
content={"valid": True, "user": payload["username"], "domain": original_host},
|
||||
headers=headers
|
||||
)
|
||||
|
||||
except HTTPException as e:
|
||||
# On token validation failure, redirect to auth service
|
||||
auth_url = f"https://{AUTH_DOMAIN}/?return_url={original_url}"
|
||||
logger.warning(f"Token validation failed: {e.detail}, redirecting to: {auth_url}")
|
||||
|
||||
response = JSONResponse(
|
||||
status_code=401,
|
||||
content={"error": e.detail, "auth_url": auth_url}
|
||||
)
|
||||
response.headers["Location"] = auth_url
|
||||
return response
|
||||
|
||||
@app.get("/auth/logout")
|
||||
async def logout():
|
||||
"""Logout endpoint"""
|
||||
response = JSONResponse({"message": "Logged out successfully"})
|
||||
response.delete_cookie("auth_token")
|
||||
return response
|
||||
|
||||
@app.get("/auth/user")
|
||||
async def get_current_user(request: Request):
|
||||
"""Get current user info from token"""
|
||||
token = request.cookies.get("auth_token")
|
||||
if not token:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
payload = verify_jwt_token(token)
|
||||
return {
|
||||
"username": payload["username"],
|
||||
"email": payload["email"],
|
||||
"display_name": payload["display_name"],
|
||||
"groups": payload["groups"]
|
||||
}
|
||||
|
||||
# liveness probe entry point
|
||||
@app.get("/health")
|
||||
async def health_check():
|
||||
"""Health check endpoint"""
|
||||
return {"status": "healthy", "timestamp": datetime.utcnow().isoformat()}
|
||||
|
||||
# Main entry point
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8080)
|
||||
|
||||
Reference in New Issue
Block a user