"""
Authentication service built with FastAPI.

It verifies user credentials against Active Directory (AD) and issues JWT
tokens for authenticated users. The token is stored in an HTTP-only cookie
scoped to the first entry of ALLOWED_DOMAINS, so every subdomain receives it.
It is checked by ``/auth/verify``, which is intended to be used as a Traefik
ForwardAuth endpoint. The service also serves the login front end for the
applications behind the proxy.
"""
import logging
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Dict, List, Optional
from urllib.parse import quote, urlsplit

import bcrypt
import jwt
import ldap3
from fastapi import Depends, FastAPI, Form, HTTPException, Request, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from ldap3.core.exceptions import LDAPException
from ldap3.utils.conv import escape_filter_chars
from pydantic import BaseModel

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
_DEFAULT_JWT_SECRET = "your-super-secret-key-change-this"
JWT_SECRET = os.getenv("JWT_SECRET", _DEFAULT_JWT_SECRET)
JWT_ALGORITHM = "HS256"
TOKEN_EXPIRE_HOURS = int(os.getenv("TOKEN_EXPIRE_HOURS", "8"))

if JWT_SECRET == _DEFAULT_JWT_SECRET:
    # The default key is published in this source file, so anyone who has it
    # can mint valid tokens. Warn loudly rather than failing, to keep existing
    # deployments starting.
    logger.warning("JWT_SECRET is not set; using the insecure built-in default key")

# Domain configuration for cross-domain auth. Entries are stripped and
# lower-cased so that values such as "a.tld, B.tld" behave as expected.
ALLOWED_DOMAINS = [
    domain.strip().lower()
    for domain in os.getenv("ALLOWED_DOMAINS", "domain.tld").split(",")
    if domain.strip()
]
if not ALLOWED_DOMAINS:
    # An empty list would break the cookie domain below and leave nothing to
    # authorize against, so fail fast on this misconfiguration.
    raise RuntimeError("ALLOWED_DOMAINS must list at least one domain")

AUTH_DOMAIN = os.getenv("AUTH_DOMAIN", "auth.domain.tld")

# Cookie domain shared by login (set) and logout (delete). The leading dot
# makes the cookie visible to every subdomain of the first allowed domain.
COOKIE_DOMAIN = f".{ALLOWED_DOMAINS[0]}"

# Browser origins allowed to make credentialed cross-origin calls. When not
# configured, the default is every allowed domain and all of its subdomains
# over HTTPS. Entries may contain "*" (see _cors_middleware_options).
CORS_ORIGINS = [
    origin.strip()
    for origin in os.getenv(
        "CORS_ORIGINS",
        ",".join(f"https://{d},https://*.{d}" for d in ALLOWED_DOMAINS),
    ).split(",")
    if origin.strip()
]

# What a "*" in a CORS origin may stand for: one or more DNS labels.
_WILDCARD_LABELS = r"[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*"


def _cors_middleware_options(origins: List[str]) -> Dict[str, object]:
    """Translate CORS origin strings into CORSMiddleware keyword arguments.

    Exact origins are passed through as ``allow_origins``. Starlette does not
    expand wildcards inside ``allow_origins``, so entries such as
    ``https://*.domain.tld`` are compiled into an anchored
    ``allow_origin_regex`` instead. A bare ``"*"`` entry keeps the
    allow-every-origin behaviour.
    """
    if "*" in origins:
        return {"allow_origins": ["*"]}
    exact = [origin for origin in origins if "*" not in origin]
    patterns = [
        re.escape(origin).replace(r"\*", _WILDCARD_LABELS)
        for origin in origins
        if "*" in origin
    ]
    options: Dict[str, object] = {"allow_origins": exact}
    if patterns:
        # \Z anchors the end, so "https://a.domain.tld.evil.com" cannot match.
        options["allow_origin_regex"] = "(?:" + "|".join(patterns) + r")\Z"
    return options


app = FastAPI(title="Authentication Service", description="AD Authentication with JWT tokens")

# CORS for cross-domain authentication. Origins come from CORS_ORIGINS rather
# than "*": combined with allow_credentials=True, "*" would let any site make
# credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    **_cors_middleware_options(CORS_ORIGINS),
)

# Active Directory Configuration
AD_SERVER = os.getenv("AD_SERVER", "ldap://your-ad-server.com")
AD_BASE_DN = os.getenv("AD_BASE_DN", "DC=yourdomain,DC=com")
AD_USER_SEARCH_BASE = os.getenv("AD_USER_SEARCH_BASE", "CN=Users,DC=yourdomain,DC=com")
AD_BIND_USER = os.getenv("AD_BIND_USER", "ReadUser")  # Service account for LDAP bind
AD_BIND_PASSWORD = os.getenv("AD_BIND_PASSWORD", "")

# Setup templates and static files
templates = Jinja2Templates(directory="templates")
app.mount("/static", StaticFiles(directory="static"), name="static")

# auto_error=False: a missing Authorization header yields None instead of a
# 403, so verify_token can fall back to the cookie.
security = HTTPBearer(auto_error=False)


class LoginRequest(BaseModel):
    username: str
    password: str


class TokenData(BaseModel):
    username: str
    email: Optional[str] = None
    groups: list = []
    exp: datetime


def _safe_unbind(connection) -> None:
    """Best-effort close of an ldap3 connection; cleanup errors are only logged."""
    try:
        connection.unbind()
    except Exception as exc:  # cleanup must never mask the real outcome
        logger.debug("Ignoring error while closing LDAP connection: %s", exc)


def verify_ad_credentials(username: str, password: str) -> dict:
    """Verify a username/password pair against Active Directory.

    The user is looked up by ``sAMAccountName`` under AD_USER_SEARCH_BASE.
    This uses the service account when AD_BIND_USER and AD_BIND_PASSWORD are
    both set, otherwise an unauthenticated connection. The password is then
    checked by binding as the found user's DN.

    Args:
        username: Account name as typed by the user (untrusted).
        password: Password as typed by the user (untrusted).

    Returns:
        A dict with ``username``, ``email``, ``display_name`` and ``groups``
        (the ``memberOf`` values as strings).

    Raises:
        HTTPException: 401 for blank input, an unknown user, a wrong password
            or any unexpected error; 500 when ldap3 raises an LDAPException
            (server unreachable, service-account bind failure, ...).
    """
    # An LDAP simple bind with an empty password is an "unauthenticated bind",
    # which some directories report as success. Never let it reach the server.
    if not username or not password:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    conn = None
    user_conn = None
    try:
        server = ldap3.Server(AD_SERVER, get_info=ldap3.ALL)

        if AD_BIND_USER and AD_BIND_PASSWORD:
            conn = ldap3.Connection(server, AD_BIND_USER, AD_BIND_PASSWORD, auto_bind=True)
        else:
            conn = ldap3.Connection(server)

        # Escape the user-supplied name so characters such as "*", "(" or ")"
        # cannot change the meaning of the filter (LDAP injection).
        search_filter = f"(sAMAccountName={escape_filter_chars(username)})"
        conn.search(AD_USER_SEARCH_BASE, search_filter, attributes=['mail', 'memberOf', 'displayName'])

        if not conn.entries:
            raise HTTPException(status_code=401, detail="Invalid credentials")

        entry = conn.entries[0]
        user_info = {
            'username': username,
            'email': str(entry.mail) if entry.mail else '',
            'display_name': str(entry.displayName) if entry.displayName else username,
            'groups': [str(group) for group in entry.memberOf] if entry.memberOf else [],
        }

        # Binding as the user's own DN is what actually verifies the password.
        user_conn = ldap3.Connection(server, entry.entry_dn, password)
        if not user_conn.bind():
            raise HTTPException(status_code=401, detail="Invalid credentials")

        logger.info("Successfully authenticated user: %s", username)
        return user_info
    except HTTPException:
        logger.warning("Authentication failed for user: %s", username)
        raise
    except LDAPException as e:
        logger.error("LDAP error: %s", e)
        raise HTTPException(status_code=500, detail="Authentication service error") from e
    except Exception as e:
        logger.error("Authentication error: %s", e)
        raise HTTPException(status_code=401, detail="Invalid credentials") from e
    finally:
        # Release both connections on every path, not only on success.
        for connection in (user_conn, conn):
            if connection is not None:
                _safe_unbind(connection)


def create_jwt_token(user_info: dict) -> str:
    """Create a signed JWT carrying the user's identity and group membership.

    Args:
        user_info: Dict as returned by verify_ad_credentials.

    Returns:
        The encoded token, which expires TOKEN_EXPIRE_HOURS after issuance.
    """
    now = datetime.now(timezone.utc)
    payload = {
        "username": user_info["username"],
        "email": user_info["email"],
        "display_name": user_info["display_name"],
        "groups": user_info["groups"],
        "exp": now + timedelta(hours=TOKEN_EXPIRE_HOURS),
        "iat": now,
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)


def verify_jwt_token(token: str) -> dict:
    """Verify the signature and expiry of *token* and return its claims.

    Raises:
        HTTPException: 401 "Token expired" or 401 "Invalid token".
    """
    try:
        return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError:
        raise HTTPException(status_code=401, detail="Token expired") from None
    except jwt.InvalidTokenError:
        # PyJWT's base class for every decode/validation failure.
        # (`jwt.JWTError` belongs to python-jose and does not exist here.)
        raise HTTPException(status_code=401, detail="Invalid token") from None


def _is_allowed_host(host: str) -> bool:
    """Return True when *host* is an allowed domain or a subdomain of one.

    Any port is ignored and the comparison is made on whole DNS labels. So
    ``app.domain.tld:8443`` matches ``domain.tld``, while ``notdomain.tld``
    and ``domain.tld.evil.com`` do not.
    """
    if not host or "@" in host:  # userinfo would make urlsplit skip a prefix
        return False
    try:
        hostname = urlsplit(f"//{host}").hostname  # lower-cased, port removed
    except ValueError:  # e.g. an unbalanced "[" in a malformed IPv6 literal
        return False
    if not hostname:
        return False
    hostname = hostname.rstrip(".")
    return any(hostname == domain or hostname.endswith(f".{domain}") for domain in ALLOWED_DOMAINS)


def _auth_required_response(detail: str, original_url: str) -> JSONResponse:
    """Build the 401 response that points the client at the login page.

    The return URL is percent-encoded so that its own query string and
    fragment survive as a single ``return_url`` parameter.
    """
    auth_url = f"https://{AUTH_DOMAIN}/?return_url={quote(original_url, safe='')}"
    response = JSONResponse(
        status_code=401,
        content={"error": detail, "auth_url": auth_url},
    )
    # NOTE: browsers only follow Location on 3xx responses. With a 401 the
    # client (e.g. front-end JS) must read auth_url itself.
    response.headers["Location"] = auth_url
    return response


@app.get("/", response_class=HTMLResponse)
async def login_page(request: Request):
    """Serve the login page"""
    return templates.TemplateResponse("login.html", {"request": request})


@app.get("/dashboard", response_class=HTMLResponse)
async def dashboard(request: Request):
    """Serve the dashboard page"""
    return templates.TemplateResponse("dashboard.html", {"request": request})


@app.post("/auth/login")
async def login(username: str = Form(...), password: str = Form(...)):
    """Authenticate form credentials against AD and issue a JWT.

    On success the token is set as an HTTP-only cookie on COOKIE_DOMAIN and
    also echoed in the ``X-Auth-Token`` header. On failure a JSON body with
    ``success: False`` and the HTTPException's status code is returned.
    """
    try:
        # ldap3 performs blocking network I/O; run it in the thread pool so the
        # event loop keeps serving other requests in the meantime.
        user_info = await run_in_threadpool(verify_ad_credentials, username, password)

        token = create_jwt_token(user_info)

        response = JSONResponse({
            "success": True,
            "message": "Login successful",
            "user": {
                "username": user_info["username"],
                "email": user_info["email"],
                "display_name": user_info["display_name"],
            },
        })

        # HTTP-only cookie shared by all subdomains of the first allowed domain.
        response.set_cookie(
            key="auth_token",
            value=token,
            domain=COOKIE_DOMAIN,
            httponly=True,
            secure=True,  # Use HTTPS in production
            samesite="lax",
            max_age=TOKEN_EXPIRE_HOURS * 3600,
        )

        # Also return token for local storage (optional)
        response.headers["X-Auth-Token"] = token

        return response
    except HTTPException as e:
        return JSONResponse(
            status_code=e.status_code,
            content={"success": False, "message": e.detail},
        )


# Traefik's ForwardAuth middleware sends its sub-request as GET. POST is kept
# for existing callers.
@app.api_route("/auth/verify", methods=["GET", "POST"])
async def verify_token(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Verify the caller's token; endpoint for Traefik ForwardAuth.

    The token is taken from the ``Authorization: Bearer`` header, falling back
    to the ``auth_token`` cookie. On success a 200 is returned, and the
    ``X-Auth-*`` headers carry the user's identity for Traefik to forward.
    On any failure (missing, expired or invalid token, or a host outside
    ALLOWED_DOMAINS) a 401 is returned that points at the login page.
    """
    # Reconstruct the originally requested URL from Traefik's forwarding headers.
    original_host = request.headers.get("X-Forwarded-Host", request.headers.get("Host", ""))
    original_proto = request.headers.get("X-Forwarded-Proto", "https")
    original_uri = request.headers.get("X-Forwarded-Uri", "/")
    original_url = request.headers.get("X-Original-URL", f"{original_proto}://{original_host}{original_uri}")

    # Authorization header first, then the cookie set by /auth/login.
    token = credentials.credentials if credentials else request.cookies.get("auth_token")

    if not token:
        response = _auth_required_response("Authentication required", original_url)
        logger.info("No token found, redirecting to: %s", response.headers["Location"])
        return response

    try:
        payload = verify_jwt_token(token)

        if not _is_allowed_host(original_host):
            logger.warning("Access denied for domain: %s", original_host)
            # Like token errors, this is reported to the client as a 401 below.
            raise HTTPException(status_code=403, detail="Domain not allowed")

        # User info in headers for Traefik (authResponseHeaders).
        # NOTE(review): header values are latin-1 encoded by Starlette, and
        # group DNs contain commas, so the joined value is ambiguous. Confirm
        # that downstream apps cope with both.
        headers = {
            "X-Auth-User": payload["username"],
            "X-Auth-Email": payload["email"],
            "X-Auth-Groups": ",".join(payload["groups"]),
            "X-Auth-Display-Name": payload["display_name"],
            "X-Auth-Domain": original_host,
        }

        logger.info("Authentication successful for %s accessing %s", payload["username"], original_host)

        return JSONResponse(
            content={"valid": True, "user": payload["username"], "domain": original_host},
            headers=headers,
        )
    except HTTPException as e:
        response = _auth_required_response(e.detail, original_url)
        logger.warning(
            "Token validation failed: %s, redirecting to: %s", e.detail, response.headers["Location"]
        )
        return response


@app.get("/auth/logout")
async def logout():
    """Clear the authentication cookie."""
    response = JSONResponse({"message": "Logged out successfully"})
    # The domain (and the default path "/") must match the ones used in
    # set_cookie. Otherwise the browser treats this as a different cookie and
    # keeps the token.
    response.delete_cookie("auth_token", domain=COOKIE_DOMAIN)
    return response


@app.get("/auth/user")
async def get_current_user(request: Request):
    """Return the identity stored in the caller's ``auth_token`` cookie.

    Raises:
        HTTPException: 401 when the cookie is missing or the token is invalid
            or expired.
    """
    token = request.cookies.get("auth_token")
    if not token:
        raise HTTPException(status_code=401, detail="Not authenticated")

    payload = verify_jwt_token(token)
    return {
        "username": payload["username"],
        "email": payload["email"],
        "display_name": payload["display_name"],
        "groups": payload["groups"],
    }


# liveness probe entry point
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {"status": "healthy", "timestamp": datetime.now(timezone.utc).isoformat()}


# Main entry point
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8080)