src.presentation.routers.oauth_callbacks

OAuth callback router for provider authentication.

Handles OAuth 2.0 Authorization Code callbacks from providers. These endpoints are external-facing: their paths are dictated by the redirect URIs registered with each provider, so they are not part of the versioned API.

Flow
  1. User initiates OAuth via frontend → Backend generates auth URL with state
  2. User authorizes at provider → Provider redirects to callback with code
  3. This router exchanges code for tokens → Creates provider connection
Security
  • CSRF protection via state parameter (stored in Redis)
  • State contains user_id + provider_slug + timestamp
  • State expires after 10 minutes (limits the window for replay attacks)

Registered Callback URLs (Schwab Developer Portal):
  • https://127.0.0.1:8182/oauth/schwab/callback (local standalone)
  • https://dashtam.local/oauth/schwab/callback (local via Traefik)
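The redirect_uri sent when the backend builds the authorization URL (Flow step 1) has to match one of the registered callback URLs, and RFC 6749 requires the same value to be repeated in the later token request. A minimal sketch of the authorization request, using the Authorization Code parameters from RFC 6749 section 4.1.1; the authorize endpoint and client ID are placeholders, not Schwab's real values, and build_authorization_url is a hypothetical name.

```python
from urllib.parse import parse_qs, urlencode, urlparse

# Placeholders: the real endpoint and client ID come from the provider's portal.
AUTHORIZE_ENDPOINT = "https://provider.example/oauth/authorize"
CLIENT_ID = "example-client-id"


def build_authorization_url(redirect_uri: str, state: str) -> str:
    """Build an OAuth 2.0 Authorization Code request URL (RFC 6749, section 4.1.1)."""
    params = {
        "response_type": "code",
        "client_id": CLIENT_ID,
        "redirect_uri": redirect_uri,  # must match a registered callback URL
        "state": state,  # echoed back by the provider to the callback
    }
    return f"{AUTHORIZE_ENDPOINT}?{urlencode(params)}"


url = build_authorization_url("https://127.0.0.1:8182/oauth/schwab/callback", "abc123")
query = parse_qs(urlparse(url).query)
print(query["redirect_uri"][0])  # https://127.0.0.1:8182/oauth/schwab/callback
```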

Reference
  • docs/architecture/provider-oauth-architecture.md

Classes

Functions

schwab_oauth_callback async

schwab_oauth_callback(
    request: Request,
    code: Annotated[
        str | None, Query(description="Authorization code")
    ] = None,
    state: Annotated[
        str | None, Query(description="CSRF state token")
    ] = None,
    error: Annotated[
        str | None, Query(description="OAuth error code")
    ] = None,
    error_description: Annotated[
        str | None,
        Query(description="OAuth error description"),
    ] = None,
    cache: CacheProtocol = Depends(get_cache),
    handler: ConnectProviderHandler = Depends(
        handler_factory(ConnectProviderHandler)
    ),
    encryption_service: EncryptionService = Depends(
        get_encryption_service
    ),
    provider_repo: ProviderRepository = Depends(
        get_provider_repository
    ),
) -> HTMLResponse

Handle Schwab OAuth 2.0 callback.

This endpoint is called by Schwab after the user authorizes the application. It exchanges the authorization code for tokens and creates the provider connection.

Query Parameters

  • code: Authorization code from Schwab (on success).
  • state: CSRF token (must match a state entry stored in the cache).
  • error: OAuth error code (on user denial or error).
  • error_description: Human-readable error description.

Returns:

  • HTMLResponse: Success or error page for the user.
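Both handlers pass error_description, taken straight from the query string, into _create_error_html. That helper's body is not shown on this page, so whether it escapes its inputs cannot be confirmed here; if it does not, anyone able to craft a callback link could inject markup into the error page. A minimal sketch of an escaping renderer using the standard library's html.escape (render_error_page is a hypothetical name):

```python
import html


def render_error_page(error_title: str, error_message: str) -> str:
    """Hypothetical stand-in for _create_error_html that escapes both inputs."""
    return (
        "<!DOCTYPE html><html><body>"
        f"<h1>{html.escape(error_title)}</h1>"
        f"<p>{html.escape(error_message)}</p>"
        "</body></html>"
    )


page = render_error_page("Authorization Denied", "<script>alert(1)</script>")
print(page)  # the script tag appears as &lt;script&gt;... and is not executed
```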

Flow
  1. Handle OAuth errors returned by the provider
  2. Validate required parameters (code and state)
  3. Validate state (CSRF protection) and check that it matches the provider
  4. Exchange authorization code for tokens
  5. Encrypt tokens for storage
  6. Look up the provider record and create the connection via the command handler
  7. Return success/error HTML to user
Source code in src/presentation/routers/oauth_callbacks.py
@oauth_router.get(
    "/oauth/schwab/callback",
    response_class=HTMLResponse,
    summary="Schwab OAuth callback",
    description="Handle OAuth 2.0 Authorization Code callback from Schwab.",
    responses={
        200: {"description": "Connection successful"},
        400: {"description": "OAuth error or invalid state"},
        500: {"description": "Internal error during token exchange"},
    },
)
async def schwab_oauth_callback(
    request: Request,
    code: Annotated[str | None, Query(description="Authorization code")] = None,
    state: Annotated[str | None, Query(description="CSRF state token")] = None,
    error: Annotated[str | None, Query(description="OAuth error code")] = None,
    error_description: Annotated[
        str | None, Query(description="OAuth error description")
    ] = None,
    cache: CacheProtocol = Depends(get_cache),
    handler: ConnectProviderHandler = Depends(handler_factory(ConnectProviderHandler)),
    encryption_service: EncryptionService = Depends(get_encryption_service),
    provider_repo: ProviderRepository = Depends(get_provider_repository),
) -> HTMLResponse:
    """Handle Schwab OAuth 2.0 callback.

    This endpoint is called by Schwab after the user authorizes the
    application. It exchanges the authorization code for tokens and creates
    the provider connection.

    Query Parameters:
        code: Authorization code from Schwab (on success).
        state: CSRF token (must match a state entry stored in the cache).
        error: OAuth error code (on user denial or error).
        error_description: Human-readable error description.

    Returns:
        HTMLResponse: Success or error page for the user.

    Flow:
        1. Handle OAuth errors returned by the provider
        2. Validate required parameters (code and state)
        3. Validate state (CSRF protection) and check that it matches the provider
        4. Exchange authorization code for tokens
        5. Encrypt tokens for storage
        6. Look up the provider record and create the connection via the command handler
        7. Return success/error HTML to user
    """
    provider_slug = "schwab"

    # Step 1: Handle OAuth errors from provider
    if error:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Authorization Denied",
                error_message=error_description or error,
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # Step 2: Validate required parameters
    if not code:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Missing Authorization Code",
                error_message="No authorization code received from provider.",
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    if not state:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Missing State Parameter",
                error_message="State parameter is required for security.",
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # Step 3: Validate state (CSRF protection)
    state_data = await _get_oauth_state_data(cache, state)
    if state_data is None:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Invalid or Expired State",
                error_message="Session expired. Please start the connection process again.",
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # Extract user info from state
    user_id = UUID(state_data["user_id"])
    stored_provider = state_data.get("provider_slug")

    # Verify provider matches
    if stored_provider != provider_slug:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Provider Mismatch",
                error_message="State does not match expected provider.",
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # Step 4: Get provider and verify OAuth capability
    try:
        provider = get_provider(provider_slug)
    except ValueError as e:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Provider Not Found",
                error_message=str(e),
            ),
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    if not is_oauth_provider(provider):
        return HTMLResponse(
            content=_create_error_html(
                error_title="Provider Not OAuth",
                error_message=f"Provider '{provider_slug}' does not support OAuth.",
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # Exchange authorization code for tokens (type narrowed to OAuthProviderProtocol)
    token_result = await provider.exchange_code_for_tokens(code)

    match token_result:
        case Failure(error=provider_error):
            return HTMLResponse(
                content=_create_error_html(
                    error_title="Token Exchange Failed",
                    error_message=provider_error.message,
                ),
                status_code=status.HTTP_400_BAD_REQUEST,
            )
        case Success(value=tokens):
            pass  # Continue with tokens

    # Step 5: Encrypt tokens for storage
    token_data = {
        "access_token": tokens.access_token,
        "refresh_token": tokens.refresh_token,
        "token_type": tokens.token_type,
        "scope": tokens.scope,
    }
    encryption_result = encryption_service.encrypt(token_data)

    match encryption_result:
        case Failure():
            return HTMLResponse(
                content=_create_error_html(
                    error_title="Encryption Failed",
                    error_message="Unable to secure credentials.",
                ),
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        case Success(value=encrypted_data):
            pass  # Continue with encrypted data

    # Step 6: Calculate token expiration
    expires_at = datetime.now(UTC) + timedelta(seconds=tokens.expires_in)

    # Step 7: Look up provider from database
    provider_entity = await provider_repo.find_by_slug(provider_slug)
    if provider_entity is None:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Provider Not Configured",
                error_message=f"Provider '{provider_slug}' is not registered in the system.",
            ),
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    # Step 8: Create provider connection via command
    credentials = ProviderCredentials(
        encrypted_data=encrypted_data,
        credential_type=CredentialType.OAUTH2,
        expires_at=expires_at,
    )

    connect_command = ConnectProvider(
        user_id=user_id,
        provider_id=provider_entity.id,
        provider_slug=provider_slug,
        credentials=credentials,
        alias=state_data.get("alias"),  # Optional alias from state
    )

    connect_result = await handler.handle(connect_command)

    match connect_result:
        case Failure(error=connect_error):
            return HTMLResponse(
                content=_create_error_html(
                    error_title="Connection Failed",
                    error_message=str(connect_error),
                ),
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        case Success():
            return HTMLResponse(
                content=_create_success_html(provider_slug),
                status_code=status.HTTP_200_OK,
            )

oauth_callback_dynamic async

oauth_callback_dynamic(
    provider_slug: str,
    request: Request,
    code: Annotated[
        str | None, Query(description="Authorization code")
    ] = None,
    state: Annotated[
        str | None, Query(description="CSRF state token")
    ] = None,
    error: Annotated[
        str | None, Query(description="OAuth error code")
    ] = None,
    error_description: Annotated[
        str | None,
        Query(description="OAuth error description"),
    ] = None,
    cache: CacheProtocol = Depends(get_cache),
    handler: ConnectProviderHandler = Depends(
        handler_factory(ConnectProviderHandler)
    ),
    encryption_service: EncryptionService = Depends(
        get_encryption_service
    ),
    provider_repo: ProviderRepository = Depends(
        get_provider_repository
    ),
) -> HTMLResponse

Handle OAuth 2.0 callback for any provider slug.

Mirrors the logic of the Schwab-specific route but uses the dynamic provider_slug path parameter to resolve the provider and validate state.

Source code in src/presentation/routers/oauth_callbacks.py
@oauth_router.get(
    "/oauth/{provider_slug}/callback",
    response_class=HTMLResponse,
    summary="OAuth callback (dynamic)",
    description="Handle OAuth 2.0 Authorization Code callback for any configured provider.",
    responses={
        200: {"description": "Connection successful"},
        400: {"description": "OAuth error or invalid state"},
        500: {"description": "Internal error during token exchange"},
    },
)
async def oauth_callback_dynamic(
    provider_slug: str,
    request: Request,
    code: Annotated[str | None, Query(description="Authorization code")] = None,
    state: Annotated[str | None, Query(description="CSRF state token")] = None,
    error: Annotated[str | None, Query(description="OAuth error code")] = None,
    error_description: Annotated[
        str | None, Query(description="OAuth error description")
    ] = None,
    cache: CacheProtocol = Depends(get_cache),
    handler: ConnectProviderHandler = Depends(handler_factory(ConnectProviderHandler)),
    encryption_service: EncryptionService = Depends(get_encryption_service),
    provider_repo: ProviderRepository = Depends(get_provider_repository),
) -> HTMLResponse:
    """Handle OAuth 2.0 callback for any provider slug.

    Mirrors the logic of the Schwab-specific route but uses the dynamic
    provider_slug path parameter to resolve the provider and validate state.
    """
    # Step 1: Handle OAuth errors from provider
    if error:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Authorization Denied",
                error_message=error_description or error,
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # Step 2: Validate required parameters
    if not code:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Missing Authorization Code",
                error_message="No authorization code received from provider.",
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    if not state:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Missing State Parameter",
                error_message="State parameter is required for security.",
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # Step 3: Validate state (CSRF protection)
    state_data = await _get_oauth_state_data(cache, state)
    if state_data is None:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Invalid or Expired State",
                error_message="Session expired. Please start the connection process again.",
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # Extract user info from state
    user_id = UUID(state_data["user_id"])
    stored_provider = state_data.get("provider_slug")

    # Verify provider matches
    if stored_provider != provider_slug:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Provider Mismatch",
                error_message="State does not match expected provider.",
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # Step 4: Get provider and verify OAuth capability
    try:
        provider = get_provider(provider_slug)
    except ValueError as e:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Provider Not Found",
                error_message=str(e),
            ),
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    if not is_oauth_provider(provider):
        return HTMLResponse(
            content=_create_error_html(
                error_title="Provider Not OAuth",
                error_message=f"Provider '{provider_slug}' does not support OAuth.",
            ),
            status_code=status.HTTP_400_BAD_REQUEST,
        )

    # Exchange authorization code for tokens (type narrowed to OAuthProviderProtocol)
    token_result = await provider.exchange_code_for_tokens(code)

    match token_result:
        case Failure(error=provider_error):
            return HTMLResponse(
                content=_create_error_html(
                    error_title="Token Exchange Failed",
                    error_message=provider_error.message,
                ),
                status_code=status.HTTP_400_BAD_REQUEST,
            )
        case Success(value=tokens):
            pass  # Continue with tokens

    # Step 5: Encrypt tokens for storage
    token_data = {
        "access_token": tokens.access_token,
        "refresh_token": tokens.refresh_token,
        "token_type": tokens.token_type,
        "scope": tokens.scope,
    }
    encryption_result = encryption_service.encrypt(token_data)

    match encryption_result:
        case Failure():
            return HTMLResponse(
                content=_create_error_html(
                    error_title="Encryption Failed",
                    error_message="Unable to secure credentials.",
                ),
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        case Success(value=encrypted_data):
            pass  # Continue with encrypted data

    # Step 6: Calculate token expiration
    expires_at = datetime.now(UTC) + timedelta(seconds=tokens.expires_in)

    # Step 7: Look up provider from database
    provider_entity = await provider_repo.find_by_slug(provider_slug)
    if provider_entity is None:
        return HTMLResponse(
            content=_create_error_html(
                error_title="Provider Not Configured",
                error_message=f"Provider '{provider_slug}' is not registered in the system.",
            ),
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )

    # Step 8: Create provider connection via command
    credentials = ProviderCredentials(
        encrypted_data=encrypted_data,
        credential_type=CredentialType.OAUTH2,
        expires_at=expires_at,
    )

    connect_command = ConnectProvider(
        user_id=user_id,
        provider_id=provider_entity.id,
        provider_slug=provider_slug,
        credentials=credentials,
        alias=state_data.get("alias"),  # Optional alias from state
    )

    connect_result = await handler.handle(connect_command)

    match connect_result:
        case Failure(error=connect_error):
            return HTMLResponse(
                content=_create_error_html(
                    error_title="Connection Failed",
                    error_message=str(connect_error),
                ),
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            )
        case Success():
            return HTMLResponse(
                content=_create_success_html(provider_slug),
                status_code=status.HTTP_200_OK,
            )