"""
WorkOS authentication provider (Primary IdP).
"""
import asyncio
import logging
import os
from datetime import datetime, timedelta
from typing import Optional

import httpx
import jwt
from jwt import PyJWKClient

from ..models.user_context import (
    AuthResult, UserContext, TokenClaims,
    MFAChallenge, VerificationResult, UserRole
)

# Module-level logger keyed on this module's import path, so its output can
# be filtered and configured independently of other modules.
logger = logging.getLogger(name=__name__)


class WorkOSAuthProvider:
    """
    Primary authentication via WorkOS.

    Wraps the WorkOS HTTP API for the OAuth authorization-code exchange,
    refresh-token exchange, RS256 JWT validation against the WorkOS JWKS
    endpoint, and MFA challenge/verification.

    The instance owns an ``httpx.AsyncClient``. Release it with
    :meth:`close`, or use the provider as an async context manager
    (``async with WorkOSAuthProvider(...) as provider: ...``).
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        client_id: Optional[str] = None,
        redirect_uri: Optional[str] = None
    ):
        """
        Initialize WorkOS provider.

        Each argument falls back to an environment variable when it is not
        given (or is empty).

        Args:
            api_key: WorkOS API key (from Azure Key Vault); falls back to
                ``WORKOS_API_KEY``.
            client_id: WorkOS client ID; falls back to ``WORKOS_CLIENT_ID``.
            redirect_uri: OAuth redirect URI; falls back to
                ``WORKOS_REDIRECT_URI``.

        Raises:
            ValueError: If any of the three values is still missing/empty.
        """
        self.api_key = api_key or os.getenv('WORKOS_API_KEY')
        self.client_id = client_id or os.getenv('WORKOS_CLIENT_ID')
        self.redirect_uri = redirect_uri or os.getenv('WORKOS_REDIRECT_URI')

        if not all([self.api_key, self.client_id, self.redirect_uri]):
            raise ValueError("WorkOS credentials not configured")

        self.base_url = "https://api.workos.com"
        # Source of the RS256 public keys used to verify WorkOS-issued JWTs.
        # PyJWKClient does blocking HTTP when it (re)fetches the key set, so
        # validate_token calls it from a worker thread.
        self.jwks_client = PyJWKClient(
            "https://api.workos.com/.well-known/jwks.json"
        )

        # Shared HTTP client, authenticated with the API key on every request.
        # NOTE: no retry policy is configured; a failed request surfaces
        # immediately to the calling method.
        self.client = httpx.AsyncClient(
            base_url=self.base_url,
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            },
            timeout=30.0
        )

    async def __aenter__(self) -> "WorkOSAuthProvider":
        """Support ``async with``; returns the provider itself."""
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Close the HTTP client when leaving an ``async with`` block."""
        await self.close()

    @staticmethod
    def _parse_timestamp(value: str) -> datetime:
        """
        Parse an ISO-8601 timestamp, accepting a trailing ``Z`` for UTC.

        ``datetime.fromisoformat`` only understands the ``Z`` suffix from
        Python 3.11 onwards, so it is rewritten to ``+00:00`` first.

        Raises:
            ValueError: If ``value`` is not a valid ISO-8601 timestamp.
        """
        if value.endswith(('Z', 'z')):
            value = value[:-1] + '+00:00'
        return datetime.fromisoformat(value)

    async def authenticate(
        self,
        authorization_code: str,
        redirect_uri: Optional[str] = None
    ) -> AuthResult:
        """
        Exchange authorization code for tokens.

        Flow:
        1. User signs in via WorkOS-hosted login
        2. WorkOS returns authorization code
        3. Exchange code for ID token + access token (this method)
        4. Validate the ID token and extract claims (this method)
        5. Return tokens and user context; creating the internal session is
           left to the caller.

        Args:
            authorization_code: OAuth authorization code
            redirect_uri: Optional override for redirect URI

        Returns:
            AuthResult with user context and tokens on success. On failure,
            an AuthResult with ``success=False`` and an ``error`` code of
            ``authentication_failed``, ``invalid_token`` or
            ``internal_error``. Exception details are logged, never
            returned to the caller.
        """
        try:
            # Exchange code for tokens
            response = await self.client.post(
                "/oauth/token",
                json={
                    "grant_type": "authorization_code",
                    "code": authorization_code,
                    "client_id": self.client_id,
                    "redirect_uri": redirect_uri or self.redirect_uri
                }
            )

            if response.status_code != 200:
                logger.error("Token exchange failed: %s", response.text)
                return AuthResult(
                    success=False,
                    error="authentication_failed",
                    error_description="Failed to exchange authorization code"
                )

            data = response.json()
            access_token = data.get('access_token')
            refresh_token = data.get('refresh_token')
            id_token = data.get('id_token')

            # Both tokens are required: without an access token the result
            # is unusable, and without an ID token there is nothing to
            # validate the user's identity against.
            if not access_token or not id_token:
                logger.error(
                    "Token exchange response missing access_token or id_token"
                )
                return AuthResult(
                    success=False,
                    error="authentication_failed",
                    error_description="Token response missing required tokens"
                )

            # Validate ID token
            claims = await self.validate_token(id_token)
            if not claims:
                return AuthResult(
                    success=False,
                    error="invalid_token",
                    error_description="Token validation failed"
                )

            # Convert claims to user context
            user_context = claims.to_user_context()

            logger.info(
                "User authenticated successfully: %s", user_context.user_id
            )

            return AuthResult(
                success=True,
                user_context=user_context,
                access_token=access_token,
                refresh_token=refresh_token,
                expires_in=data.get('expires_in', 3600)
            )

        except Exception as e:
            # Full details go to the log only; the caller gets a generic
            # description so internal errors are not exposed to clients.
            logger.exception("Authentication error: %s", e)
            return AuthResult(
                success=False,
                error="internal_error",
                error_description="Internal authentication error"
            )

    async def validate_token(self, token: str) -> Optional[TokenClaims]:
        """
        Validate JWT token from WorkOS.

        Validations:
        - Signature verification (RS256) with a key looked up in the WorkOS
          JWKS by the token's key id. Tokens whose signing key is no longer
          published are therefore rejected. There is no per-token
          revocation check.
        - Issuer verification (``https://api.workos.com``)
        - Expiration check (by ``jwt.decode`` and again via
          ``TokenClaims.is_expired``)
        - Audience validation (must equal this provider's client ID)

        Args:
            token: JWT token to validate

        Returns:
            TokenClaims if valid, None otherwise (including for an empty or
            missing token).
        """
        if not token:
            logger.warning("Token validation failed: no token provided")
            return None

        try:
            # The JWKS lookup may perform a blocking HTTP fetch; run it in
            # the default executor so the event loop is not stalled.
            loop = asyncio.get_running_loop()
            signing_key = await loop.run_in_executor(
                None, self.jwks_client.get_signing_key_from_jwt, token
            )

            # Decode and validate token
            payload = jwt.decode(
                token,
                signing_key.key,
                algorithms=["RS256"],
                audience=self.client_id,
                issuer="https://api.workos.com"
            )

            # Extract claims; firm_id/role/metadata fall back to defaults
            # when absent from the token.
            claims = TokenClaims(
                sub=payload['sub'],
                iss=payload['iss'],
                aud=payload['aud'],
                exp=payload['exp'],
                iat=payload['iat'],
                firm_id=payload.get('firm_id', ''),
                role=payload.get('role', UserRole.CLIENT.value),
                email=payload.get('email'),
                metadata=payload.get('metadata', {})
            )

            # Defensive re-check using the project's own expiry rule.
            if claims.is_expired():
                logger.warning("Token expired for user: %s", claims.sub)
                return None

            return claims

        except jwt.InvalidTokenError as e:
            logger.warning("Token validation failed: %s", e)
            return None
        except Exception as e:
            logger.exception("Token validation error: %s", e)
            return None

    async def refresh_token(self, refresh_token: str) -> Optional[AuthResult]:
        """
        Refresh expired access token.

        Args:
            refresh_token: Refresh token

        Returns:
            AuthResult with new tokens and the user context derived from the
            new access token, or None if the refresh request fails, the
            response has no access token, or that token does not validate.
        """
        try:
            response = await self.client.post(
                "/oauth/token",
                json={
                    "grant_type": "refresh_token",
                    "refresh_token": refresh_token,
                    "client_id": self.client_id
                }
            )

            if response.status_code != 200:
                logger.error("Token refresh failed: %s", response.text)
                return None

            data = response.json()
            access_token = data.get('access_token')
            if not access_token:
                logger.error("Token refresh response missing access_token")
                return None

            # Validate new token
            claims = await self.validate_token(access_token)
            if not claims:
                return None

            return AuthResult(
                success=True,
                user_context=claims.to_user_context(),
                access_token=access_token,
                refresh_token=data.get('refresh_token'),
                expires_in=data.get('expires_in', 3600)
            )

        except Exception as e:
            logger.exception("Token refresh error: %s", e)
            return None

    async def revoke_session(self, user_id: str, session_id: str):
        """
        Record a session logout.

        This method makes no WorkOS API call and does not itself invalidate
        anything. WorkOS has no direct session revocation, so revocation
        relies on the Redis session store, which is handled outside this
        class. Only a log entry is written here.

        Args:
            user_id: User ID
            session_id: Session ID being logged out
        """
        logger.info("Session revoked: %s for user %s", session_id, user_id)

    async def require_mfa_verification(
        self,
        user_id: str,
        session_id: str
    ) -> MFAChallenge:
        """
        Initiate MFA challenge.

        Args:
            user_id: User ID
            session_id: Session ID

        Returns:
            MFAChallenge with challenge ID and available methods

        Raises:
            RuntimeError: If WorkOS does not answer with HTTP 200.
            Exception: Network, JSON, missing-field (KeyError) or timestamp
                (ValueError) errors are logged and re-raised unchanged.
        """
        try:
            response = await self.client.post(
                "/mfa/challenges",
                json={
                    "user_id": user_id,
                    "session_id": session_id
                }
            )

            if response.status_code != 200:
                raise RuntimeError(f"MFA challenge failed: {response.text}")

            data = response.json()

            return MFAChallenge(
                challenge_id=data['id'],
                user_id=user_id,
                methods=data['methods'],
                # Accepts UTC timestamps ending in "Z" on all Python versions.
                expires_at=self._parse_timestamp(data['expires_at'])
            )

        except Exception as e:
            logger.exception("MFA challenge error: %s", e)
            raise

    async def verify_mfa_code(
        self,
        challenge_id: str,
        code: str,
        remember_device: bool = False
    ) -> VerificationResult:
        """
        Verify MFA code.

        TOTP: Time-based 6-digit code
        SMS: 6-digit code sent via Twilio
        WebAuthn: FIDO2 challenge-response

        Args:
            challenge_id: Challenge ID
            code: Verification code
            remember_device: Whether to remember this device

        Returns:
            VerificationResult with ``success=True`` and the ``session_id``
            from the response on HTTP 200. Otherwise ``success=False`` with
            a generic error message; exception details are logged, never
            returned to the caller.
        """
        try:
            response = await self.client.post(
                f"/mfa/challenges/{challenge_id}/verify",
                json={
                    "code": code,
                    "remember_device": remember_device
                }
            )

            if response.status_code != 200:
                return VerificationResult(
                    success=False,
                    error="Invalid verification code"
                )

            data = response.json()

            return VerificationResult(
                success=True,
                session_id=data.get('session_id')
            )

        except Exception as e:
            logger.exception("MFA verification error: %s", e)
            return VerificationResult(
                success=False,
                error="MFA verification failed"
            )

    async def close(self):
        """Close HTTP client."""
        await self.client.aclose()
