"""
FastAPI authentication middleware.
"""
import logging
from typing import Callable, Optional

from fastapi import Depends, HTTPException, Request, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.middleware.base import BaseHTTPMiddleware

from ..models.user_context import UserContext
from ..providers.ory_provider import ORYAuthProvider
from ..providers.workos_provider import WorkOSAuthProvider
from ..utils.api_key_manager import APIKeyManager

# Logger named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)

# HTTP Bearer security scheme (not referenced by the middleware below;
# presumably imported by other modules -- verify before removing).
security: HTTPBearer = HTTPBearer()


class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Authentication middleware for FastAPI.

    Authenticates every request whose path is not in ``exclude_paths`` and
    stores the resulting ``UserContext`` on ``request.state.user``.
    Unauthenticated requests receive a 401 JSON response; unexpected errors
    while authenticating receive a 500 JSON response.
    """

    def __init__(
        self,
        app,
        workos_provider: WorkOSAuthProvider,
        ory_provider: ORYAuthProvider,
        api_key_manager: APIKeyManager,
        exclude_paths: Optional[list] = None
    ):
        """
        Initialize authentication middleware.

        Args:
            app: ASGI application wrapped by this middleware
            workos_provider: WorkOS provider, tried first for bearer tokens
            ory_provider: ORY provider, used for token introspection
                (fallback) and session-cookie authentication
            api_key_manager: API key manager used for X-API-Key auth
            exclude_paths: Exact request paths that skip authentication.
                ``None`` selects the default public paths; an explicit empty
                list means every path is authenticated.
        """
        super().__init__(app)
        self.workos_provider = workos_provider
        self.ory_provider = ory_provider
        self.api_key_manager = api_key_manager
        # `is None` rather than `or`, so an explicit [] is honoured instead of
        # being silently replaced by the defaults.
        if exclude_paths is None:
            exclude_paths = [
                "/health",
                "/docs",
                "/openapi.json",
                "/auth/login",
                "/auth/callback"
            ]
        self.exclude_paths = exclude_paths

    async def dispatch(self, request: Request, call_next: Callable):
        """
        Authenticate the request, then pass it down the middleware chain.

        Failures are returned as JSON responses rather than raised: an
        ``HTTPException`` raised inside ``BaseHTTPMiddleware.dispatch`` is
        outside the reach of FastAPI's exception handlers and would reach the
        client as a 500 instead of the intended status code.

        Args:
            request: HTTP request
            call_next: Next middleware in chain

        Returns:
            HTTP response
        """
        # Skip authentication for excluded paths (exact path match only).
        if request.url.path in self.exclude_paths:
            return await call_next(request)

        # Guard only the authentication step; exceptions from downstream
        # handlers propagate normally instead of being reported as auth errors.
        try:
            user_context = await self._authenticate_request(request)
        except HTTPException as exc:
            return JSONResponse(
                status_code=exc.status_code,
                content={"detail": exc.detail},
                headers=exc.headers,
            )
        except Exception as exc:
            logger.exception("Authentication error: %s", exc)
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"detail": "Internal server error"},
            )

        if not user_context:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={"detail": "Authentication required"},
                headers={"WWW-Authenticate": "Bearer"},
            )

        # Attach user context to request state for get_current_user & routes.
        request.state.user = user_context
        return await call_next(request)

    async def _authenticate_request(
        self,
        request: Request
    ) -> Optional[UserContext]:
        """
        Authenticate request using various methods.

        Priority (the first credential that is present decides; a present
        but invalid credential does not fall back to the next method):
        1. JWT token (Authorization: Bearer)
        2. API key (X-API-Key header)
        3. Session cookie (session_token)

        Args:
            request: HTTP request

        Returns:
            UserContext if authenticated, otherwise None
        """
        # Try JWT token. The auth scheme is case-insensitive (RFC 7235), and
        # partition()+strip() tolerates extra whitespace around the token.
        auth_header = request.headers.get('Authorization')
        if auth_header:
            scheme, _, credentials = auth_header.partition(' ')
            if scheme.lower() == 'bearer':
                return await self._authenticate_jwt(credentials.strip())

        # Try API key
        api_key = request.headers.get('X-API-Key')
        if api_key:
            return await self._authenticate_api_key(api_key)

        # Try session cookie
        session_token = request.cookies.get('session_token')
        if session_token:
            return await self._authenticate_session(session_token)

        return None

    async def _authenticate_jwt(self, token: str) -> Optional[UserContext]:
        """
        Authenticate using a bearer token.

        WorkOS validation is tried first; if it yields nothing *or raises*,
        ORY token introspection is tried next.

        Args:
            token: Bearer token (may be empty if the header had no token)

        Returns:
            UserContext if either provider accepts the token, otherwise None
        """
        if not token:
            return None

        # Try WorkOS first (primary IdP). Handled separately so a WorkOS
        # failure does not prevent the ORY fallback below.
        try:
            claims = await self.workos_provider.validate_token(token)
            if claims:
                return claims.to_user_context()
        except Exception as exc:
            logger.debug("WorkOS token validation failed, trying ORY: %s", exc)

        # Try ORY token introspection
        try:
            token_data = await self.ory_provider.introspect_token(token)
            if not token_data:
                return None
            # NOTE(review): if introspect_token returns the raw RFC 7662
            # response, an inactive token is a truthy dict such as
            # {"active": False}; reject it explicitly. Also refuse to build a
            # context without a subject (user_id would be None).
            if token_data.get('active') is False or not token_data.get('sub'):
                return None
            return UserContext(
                user_id=token_data.get('sub'),
                firm_id=token_data.get('firm_id', ''),
                role=token_data.get('role', 'client'),
                email=token_data.get('email')
            )
        except Exception as exc:
            logger.warning("JWT authentication failed: %s", exc)
            return None

    async def _authenticate_api_key(
        self,
        api_key: str
    ) -> Optional[UserContext]:
        """
        Authenticate using API key.

        Args:
            api_key: API key from the X-API-Key header

        Returns:
            UserContext with the 'ai_agent' role if valid, otherwise None
        """
        try:
            key_info = await self.api_key_manager.validate_api_key(api_key)
            if not key_info:
                return None

            # Create user context from API key
            return UserContext(
                user_id=key_info.key_id,
                firm_id='',  # API keys may not have firm context
                role='ai_agent',  # Service account role
                permissions=key_info.permissions,
                metadata={'api_key_name': key_info.name}
            )

        except Exception as exc:
            logger.warning("API key authentication failed: %s", exc)
            return None

    async def _authenticate_session(
        self,
        session_token: str
    ) -> Optional[UserContext]:
        """
        Authenticate using session token.

        Args:
            session_token: Value of the session_token cookie

        Returns:
            Whatever ory_provider.get_session returns (expected to be a
            UserContext or None), or None if it raises
        """
        try:
            return await self.ory_provider.get_session(session_token)

        except Exception as exc:
            logger.warning("Session authentication failed: %s", exc)
            return None


async def get_current_user(request: Request) -> UserContext:
    """
    Dependency to get current authenticated user.

    Reads the ``UserContext`` that AuthenticationMiddleware stored on
    ``request.state.user``.

    Args:
        request: HTTP request

    Returns:
        UserContext

    Raises:
        HTTPException: 401 (with a WWW-Authenticate challenge, as RFC 7235
            requires for 401 responses) if no user is attached.
    """
    # getattr with a default also treats an explicit None as unauthenticated.
    user = getattr(request.state, 'user', None)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required",
            headers={"WWW-Authenticate": "Bearer"}
        )

    return user


def require_role(*allowed_roles: str) -> Callable:
    """
    Build a dependency that requires the current user to have a given role.

    Must be a plain (non-async) factory: ``Depends(require_role("admin"))``
    needs the returned callable, not a coroutine.

    Args:
        allowed_roles: Allowed role names

    Returns:
        Async dependency function that resolves the current user via
        ``get_current_user`` and returns it if its role is allowed.

    Raises (from the returned dependency):
        HTTPException: 403 if the user's role is not in ``allowed_roles``.
    """
    async def check_role(
        user: UserContext = Depends(get_current_user)
    ) -> UserContext:
        if user.role not in allowed_roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Role required: {', '.join(allowed_roles)}"
            )
        return user

    return check_role
