299 lines
11 KiB
Python
299 lines
11 KiB
Python
from fastapi import FastAPI, HTTPException, Request, Response, Depends, Form
|
|
from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
from fastapi.staticfiles import StaticFiles
|
|
from fastapi.templating import Jinja2Templates
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
import jwt
|
|
import bcrypt
|
|
import ldap3
|
|
from datetime import datetime, timedelta
|
|
import os
|
|
import logging
|
|
from typing import Optional
|
|
from pydantic import BaseModel
|
|
"""
Authentication service built with FastAPI.

It verifies user credentials against Active Directory (AD) and issues JWT tokens
to authenticated users. It supports cross-domain authentication (a session cookie
shared across subdomains) and is designed to run behind Traefik as a reverse
proxy, acting as the login front end for the applications it protects.
"""
|
|
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def _env_list(name: str, default: str) -> list:
    """Return the comma-separated env var *name* as a list of trimmed, non-empty items.

    Whitespace around items and stray commas are ignored (``"a.tld, b.tld,"``
    gives ``["a.tld", "b.tld"]``). Falls back to parsing *default* when the
    variable is unset or yields no items.
    """
    items = [part.strip() for part in os.getenv(name, default).split(",") if part.strip()]
    return items or [part.strip() for part in default.split(",") if part.strip()]


def _wildcard_origins_regex(origins: list) -> Optional[str]:
    """Build a CORS ``allow_origin_regex`` from the origins that contain ``*``.

    ``https://*.domain.tld`` matches ``https://app.domain.tld`` and
    ``https://a.b.domain.tld``. A bare ``"*"`` entry is left to
    ``allow_origins``. Returns ``None`` when there are no wildcard entries.
    """
    import re

    patterns = [
        re.escape(origin).replace(r"\*", r"[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)*")
        for origin in origins
        if "*" in origin and origin != "*"
    ]
    if not patterns:
        return None
    return "^(?:" + "|".join(patterns) + ")$"


# Configuration
_DEFAULT_JWT_SECRET = "your-super-secret-key-change-this"
JWT_SECRET = os.getenv("JWT_SECRET", _DEFAULT_JWT_SECRET)
if JWT_SECRET == _DEFAULT_JWT_SECRET:
    # Anyone who knows the published default can forge valid tokens.
    logger.warning("JWT_SECRET is not set; using the built-in default. Set JWT_SECRET in production.")
JWT_ALGORITHM = "HS256"
TOKEN_EXPIRE_HOURS = int(os.getenv("TOKEN_EXPIRE_HOURS", "8"))

# Domain configuration for cross-domain auth (defined once; everything below reads these).
# Leading dots are dropped because the cookie code adds its own ("." + domain).
ALLOWED_DOMAINS = [domain.lstrip(".").lower() for domain in _env_list("ALLOWED_DOMAINS", "domain.tld")]
AUTH_DOMAIN = os.getenv("AUTH_DOMAIN", "auth.domain.tld")
CORS_ORIGINS = _env_list("CORS_ORIGINS", "https://*.domain.tld")

app = FastAPI(title="Authentication Service", description="AD Authentication with JWT tokens")

# CORS for cross-domain authentication: allow only the origins configured in
# CORS_ORIGINS. Exact entries go to allow_origins; "*" wildcard entries are
# matched through a regex. A blanket "*" combined with allow_credentials=True
# would expose credentialed responses to every site.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[origin for origin in CORS_ORIGINS if origin == "*" or "*" not in origin],
    allow_origin_regex=_wildcard_origins_regex(CORS_ORIGINS),
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)

# Active Directory Configuration
AD_SERVER = os.getenv("AD_SERVER", "ldap://your-ad-server.com")
AD_BASE_DN = os.getenv("AD_BASE_DN", "DC=yourdomain,DC=com")
AD_USER_SEARCH_BASE = os.getenv("AD_USER_SEARCH_BASE", "CN=Users,DC=yourdomain,DC=com")
AD_BIND_USER = os.getenv("AD_BIND_USER", "ReadUser")  # Service account for LDAP bind
AD_BIND_PASSWORD = os.getenv("AD_BIND_PASSWORD", "")
|
|
|
|
# Bearer-token extractor. With auto_error=False a missing Authorization header
# yields None instead of an error, so endpoints can fall back to the cookie.
security = HTTPBearer(auto_error=False)

# Jinja2 templates for the HTML pages; static assets are served under /static.
templates = Jinja2Templates(directory="templates")
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
|
|
|
|
class LoginRequest(BaseModel):
    """Request-body schema holding a username and password.

    NOTE(review): not referenced by the visible endpoints (``/auth/login``
    reads ``Form`` fields instead); confirm whether it is still needed.
    """

    username: str  # account name to authenticate
    password: str  # plaintext password supplied by the client
|
|
|
|
class TokenData(BaseModel):
    """Typed view of the identity claims carried in a JWT.

    NOTE(review): not referenced in the visible code, and it lacks the
    ``display_name`` and ``iat`` claims that ``create_jwt_token`` writes;
    confirm before relying on it to parse tokens.
    """

    username: str  # account name of the authenticated user
    email: Optional[str] = None  # may be absent
    groups: list = []  # group identifiers of the user
    exp: datetime  # token expiry time
|
|
|
|
def verify_ad_credentials(username: str, password: str) -> dict:
    """Verify credentials against Active Directory.

    Looks the user up by ``sAMAccountName`` (through the service account when
    AD_BIND_USER/AD_BIND_PASSWORD are configured), then proves the password by
    binding to the directory as the user's own DN.

    Args:
        username: Account name supplied by the client.
        password: Password supplied by the client.

    Returns:
        A dict with ``username``, ``email``, ``display_name`` and ``groups``
        (the user's ``memberOf`` values as strings).

    Raises:
        HTTPException: 401 "Invalid credentials" for empty input, unknown users,
            a failed user bind or unexpected errors; 500 "Authentication
            service error" when ldap3 raises an LDAP error.
    """
    from ldap3.utils.conv import escape_filter_chars

    # An LDAP simple bind with an empty password is an "unauthenticated bind",
    # which AD reports as success. Reject it before contacting the directory,
    # otherwise any known username logs in without a password.
    if not username or not password:
        raise HTTPException(status_code=401, detail="Invalid credentials")

    conn = None
    user_conn = None
    try:
        # Connect to AD server
        server = ldap3.Server(AD_SERVER, get_info=ldap3.ALL)

        # If we have a service account, use it for initial bind
        if AD_BIND_USER and AD_BIND_PASSWORD:
            conn = ldap3.Connection(server, AD_BIND_USER, AD_BIND_PASSWORD, auto_bind=True)
        else:
            conn = ldap3.Connection(server)

        # Escape the client-supplied name so characters like "*", "(" or ")"
        # cannot change the meaning of the filter (LDAP injection).
        search_filter = f"(sAMAccountName={escape_filter_chars(username)})"
        conn.search(AD_USER_SEARCH_BASE, search_filter, attributes=['mail', 'memberOf', 'displayName'])

        if not conn.entries:
            raise HTTPException(status_code=401, detail="Invalid credentials")

        entry = conn.entries[0]
        user_dn = entry.entry_dn
        user_info = {
            'username': username,
            'email': str(entry.mail) if entry.mail else '',
            'display_name': str(entry.displayName) if entry.displayName else username,
            'groups': [str(group) for group in entry.memberOf] if entry.memberOf else [],
        }

        # Now try to bind with user credentials to verify password
        user_conn = ldap3.Connection(server, user_dn, password)
        if not user_conn.bind():
            raise HTTPException(status_code=401, detail="Invalid credentials")

        logger.info(f"Successfully authenticated user: {username}")
        return user_info

    except HTTPException:
        # Already the intended client-facing error; don't rewrap or log it as a fault.
        raise
    except ldap3.core.exceptions.LDAPException as e:
        logger.error(f"LDAP error: {str(e)}")
        raise HTTPException(status_code=500, detail="Authentication service error") from e
    except Exception as e:
        logger.exception(f"Authentication error: {str(e)}")
        raise HTTPException(status_code=401, detail="Invalid credentials") from e
    finally:
        # Release both connections on every path, including failed lookups/binds.
        for connection in (user_conn, conn):
            if connection is not None:
                try:
                    connection.unbind()
                except ldap3.core.exceptions.LDAPException as e:
                    logger.debug(f"Ignoring error while closing LDAP connection: {e}")
|
|
|
|
def create_jwt_token(user_info: dict) -> str:
    """Create a signed JWT carrying the user's identity.

    Args:
        user_info: Mapping with ``username``, ``email``, ``display_name`` and
            ``groups`` keys, as produced by ``verify_ad_credentials``.

    Returns:
        The token signed with JWT_SECRET using JWT_ALGORITHM. Its ``exp`` claim
        is TOKEN_EXPIRE_HOURS after its ``iat`` claim.
    """
    from datetime import timezone

    # Take a single timezone-aware "now" so iat and exp derive from the same
    # instant (datetime.utcnow() is naive and deprecated since Python 3.12).
    issued_at = datetime.now(timezone.utc)
    payload = {
        "username": user_info["username"],
        "email": user_info["email"],
        "display_name": user_info["display_name"],
        "groups": user_info["groups"],
        "exp": issued_at + timedelta(hours=TOKEN_EXPIRE_HOURS),
        "iat": issued_at,
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
|
|
|
def verify_jwt_token(token: str) -> dict:
    """Verify a JWT's signature and expiry and return its claims.

    Args:
        token: The encoded JWT (from the Authorization header or cookie).

    Returns:
        The decoded claims dict.

    Raises:
        HTTPException: 401 "Token expired" when ``exp`` has passed; 401
            "Invalid token" for any other validation failure (bad signature,
            malformed token, disallowed algorithm, ...).
    """
    try:
        return jwt.decode(token, JWT_SECRET, algorithms=[JWT_ALGORITHM])
    except jwt.ExpiredSignatureError:
        # Subclass of InvalidTokenError, so it must be handled first.
        raise HTTPException(status_code=401, detail="Token expired")
    except jwt.InvalidTokenError:
        # PyJWT's base class for decode/validation errors. PyJWT has no
        # ``JWTError`` (that is python-jose), so catching that name raised
        # AttributeError and turned every forged/malformed token into a 500.
        raise HTTPException(status_code=401, detail="Invalid token")
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def login_page(request: Request):
    """Render the ``login.html`` template; the request is the only context value."""
    context = {"request": request}
    return templates.TemplateResponse("login.html", context)
|
|
|
|
@app.get("/dashboard", response_class=HTMLResponse)
async def dashboard(request: Request):
    """Render the ``dashboard.html`` template with the incoming request as context."""
    page = "dashboard.html"
    return templates.TemplateResponse(page, {"request": request})
|
|
|
|
@app.post("/auth/login")
async def login(username: str = Form(...), password: str = Form(...)):
    """Authenticate user and return JWT token.

    Verifies the form credentials against AD. On success it returns the user's
    profile as JSON and stores a signed JWT both in the ``auth_token`` cookie
    (scoped to the first allowed domain so every subdomain receives it) and in
    the ``X-Auth-Token`` response header.

    Returns:
        200 ``{"success": True, "message": ..., "user": {...}}`` on success;
        otherwise the status raised by ``verify_ad_credentials`` with
        ``{"success": False, "message": <detail>}``.
    """
    from fastapi.concurrency import run_in_threadpool

    try:
        # ldap3 does blocking network I/O; run it in the thread pool so a slow
        # directory does not stall the event loop for every other request.
        user_info = await run_in_threadpool(verify_ad_credentials, username, password)
    except HTTPException as e:
        return JSONResponse(
            status_code=e.status_code,
            content={"success": False, "message": e.detail},
        )

    token = create_jwt_token(user_info)

    # Create response with token in cookie
    response = JSONResponse({
        "success": True,
        "message": "Login successful",
        "user": {
            "username": user_info["username"],
            "email": user_info["email"],
            "display_name": user_info["display_name"],
        },
    })

    response.set_cookie(
        key="auth_token",
        value=token,
        domain=f".{ALLOWED_DOMAINS[0]}",  # leading dot: shared with all subdomains
        httponly=True,  # not readable from page JavaScript
        secure=True,  # only sent over HTTPS
        samesite="lax",
        max_age=TOKEN_EXPIRE_HOURS * 3600,  # same lifetime as the JWT's exp
    )

    # NOTE(review): returning the token in a header lets page scripts read it,
    # which sidesteps the HttpOnly cookie. Kept for clients that store it
    # themselves ("optional" per the original design); drop it if none do.
    response.headers["X-Auth-Token"] = token

    return response
|
|
|
|
def _host_is_allowed(host: str, domains: list) -> bool:
    """Return True if *host* equals one of *domains* or is a subdomain of one.

    Case, an optional ``:port`` suffix and a trailing dot are ignored. Unlike a
    substring test, this rejects look-alike hosts such as
    ``domain.tld.evil.example`` or ``notdomain.tld``.
    """
    hostname = host.strip().lower().partition(":")[0].rstrip(".")
    for domain in domains:
        suffix = domain.strip().lower().strip(".")
        if suffix and (hostname == suffix or hostname.endswith("." + suffix)):
            return True
    return False


def _login_url(original_url: str) -> str:
    """Return the auth-service URL that sends the user back to *original_url* after login."""
    from urllib.parse import quote

    # Encode the whole URL so its own "?", "&" and "#" stay inside return_url.
    return f"https://{AUTH_DOMAIN}/?return_url={quote(original_url, safe='')}"


def _auth_required_response(auth_url: str, error: str) -> JSONResponse:
    """Build the 401 JSON response that points the client at *auth_url*."""
    response = JSONResponse(
        status_code=401,
        content={"error": error, "auth_url": auth_url},
    )
    # NOTE(review): browsers only follow Location on 3xx responses; with a 401
    # the client has to read auth_url itself. Consider 302 for browser traffic.
    response.headers["Location"] = auth_url
    return response


# Traefik's ForwardAuth middleware calls the auth address with GET, so GET is
# accepted alongside the original POST.
@app.api_route("/auth/verify", methods=["GET", "POST"])
async def verify_token(request: Request, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Verify token endpoint for Traefik ForwardAuth.

    Takes the bearer token from the Authorization header, falling back to the
    ``auth_token`` cookie, and checks the forwarded host against ALLOWED_DOMAINS.

    Returns:
        200 with ``X-Auth-*`` identity headers when the token is valid and the
        host is allowed; 401 with ``auth_url`` (and a Location header) when the
        token is missing, expired or invalid; 403 when the host is not allowed.
    """
    # Get original request information from Traefik headers
    original_host = request.headers.get("X-Forwarded-Host", request.headers.get("Host", ""))
    original_proto = request.headers.get("X-Forwarded-Proto", "https")
    original_uri = request.headers.get("X-Forwarded-Uri", "/")
    original_url = request.headers.get("X-Original-URL", f"{original_proto}://{original_host}{original_uri}")

    # Check Authorization header first, then the shared cookie.
    token = credentials.credentials if credentials else request.cookies.get("auth_token")

    if not token:
        auth_url = _login_url(original_url)
        logger.info(f"No token found, redirecting to: {auth_url}")
        return _auth_required_response(auth_url, "Authentication required")

    try:
        payload = verify_jwt_token(token)
    except HTTPException as e:
        # On token validation failure, redirect to auth service
        auth_url = _login_url(original_url)
        logger.warning(f"Token validation failed: {e.detail}, redirecting to: {auth_url}")
        return _auth_required_response(auth_url, e.detail)

    # A disallowed host is answered with 403, not a login redirect: the user's
    # token is valid, so sending them to log in again would just loop.
    if not _host_is_allowed(original_host, ALLOWED_DOMAINS):
        logger.warning(f"Access denied for domain: {original_host}")
        return JSONResponse(status_code=403, content={"error": "Domain not allowed"})

    # Return user info in headers for Traefik.
    # NOTE(review): group values are AD DNs, which themselves contain commas,
    # so the comma-joined X-Auth-Groups is ambiguous to parse. Header values are
    # also presumably latin-1 encoded by the framework, so non-latin-1 display
    # names may fail here. Confirm with downstream consumers.
    headers = {
        "X-Auth-User": payload["username"],
        "X-Auth-Email": payload["email"],
        "X-Auth-Groups": ",".join(payload["groups"]),
        "X-Auth-Display-Name": payload["display_name"],
        "X-Auth-Domain": original_host,
    }

    logger.info(f"Authentication successful for {payload['username']} accessing {original_host}")

    return JSONResponse(
        content={"valid": True, "user": payload["username"], "domain": original_host},
        headers=headers,
    )
|
|
|
|
@app.get("/auth/logout")
async def logout():
    """Logout endpoint: clear the ``auth_token`` cookie.

    ``/auth/login`` sets the cookie with ``domain=".<first allowed domain>"``.
    Browsers identify a cookie by name, domain and path, so the deletion must
    repeat that domain. Otherwise the shared cookie survives and the user stays
    logged in.

    NOTE(review): tokens are stateless, so a copy of the JWT (e.g. from the
    ``X-Auth-Token`` header) remains valid until its ``exp``.
    """
    response = JSONResponse({"message": "Logged out successfully"})
    response.delete_cookie("auth_token", path="/", domain=f".{ALLOWED_DOMAINS[0]}")
    return response
|
|
|
|
@app.get("/auth/user")
async def get_current_user(request: Request):
    """Return the profile claims from the JWT in the ``auth_token`` cookie.

    Raises:
        HTTPException: 401 "Not authenticated" when the cookie is missing or
            empty. Errors from ``verify_jwt_token`` propagate unchanged.
    """
    cookie_token = request.cookies.get("auth_token")
    if not cookie_token:
        raise HTTPException(status_code=401, detail="Not authenticated")

    claims = verify_jwt_token(cookie_token)
    # Expose only the profile claims, in a fixed order (exp/iat are left out).
    profile_keys = ("username", "email", "display_name", "groups")
    return {key: claims[key] for key in profile_keys}
|
|
|
|
# liveness probe entry point
@app.get("/health")
async def health_check():
    """Health check endpoint: report status plus the current UTC time (ISO 8601)."""
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12. Dropping tzinfo keeps
    # the exact previous format (no "+00:00" suffix) for existing probes.
    now_utc = datetime.now(timezone.utc).replace(tzinfo=None)
    return {"status": "healthy", "timestamp": now_utc.isoformat()}
|
|
|
|
# Main entry point: running this file directly serves the app with uvicorn.
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces on port 8080 so the reverse proxy can reach the service.
    uvicorn.run(app, port=8080, host="0.0.0.0")
|
|
|