Skip to content

infrastructure.persistence.repositories.refresh_token_repository

src.infrastructure.persistence.repositories.refresh_token_repository

RefreshTokenRepository - SQLAlchemy implementation for refresh token persistence.

Handles CRUD operations for refresh tokens with automatic expiration checks.

Classes

RefreshTokenRepository

SQLAlchemy implementation for refresh token persistence.

Manages refresh tokens with support for:

- Token creation and storage
- Token validation (hash lookup)
- Token rotation (delete old, create new)
- Session-based revocation

Attributes:

Name Type Description
session

SQLAlchemy async session for database operations.

Example

>>> async with get_session() as session:
...     repo = RefreshTokenRepository(session)
...     token = await repo.find_by_token_hash(token_hash)

Source code in src/infrastructure/persistence/repositories/refresh_token_repository.py
class RefreshTokenRepository:
    """SQLAlchemy implementation for refresh token persistence.

    Manages refresh tokens with support for:
    - Token creation and storage
    - Token validation (hash lookup)
    - Token rotation (delete old, create new)
    - Session-based revocation

    Every mutating method commits the session itself before returning.

    Attributes:
        session: SQLAlchemy async session for database operations.

    Example:
        >>> async with get_session() as session:
        ...     repo = RefreshTokenRepository(session)
        ...     token = await repo.find_by_token_hash(token_hash)
    """

    def __init__(self, session: AsyncSession) -> None:
        """Initialize repository with database session.

        Args:
            session: SQLAlchemy async session.
        """
        self.session = session

    async def save(
        self,
        user_id: UUID,
        token_hash: str,
        session_id: UUID,
        expires_at: datetime,
        *,
        token_version: int = 1,
        global_version_at_issuance: int = 1,
    ) -> RefreshTokenData:
        """Create new refresh token in database.

        The row is added, committed, and then refreshed from the database
        before being converted with ``_to_data``.

        Args:
            user_id: User's unique identifier.
            token_hash: Bcrypt hash of the refresh token.
            session_id: Associated session ID.
            expires_at: Token expiration timestamp.
            token_version: Token version at issuance (for breach rotation).
            global_version_at_issuance: Global min version when issued.

        Returns:
            Created RefreshTokenData.
        """
        token_model = RefreshToken(
            user_id=user_id,
            token_hash=token_hash,
            session_id=session_id,
            expires_at=expires_at,
            token_version=token_version,
            global_version_at_issuance=global_version_at_issuance,
        )
        self.session.add(token_model)
        await self.session.commit()
        # Reload the instance after commit so the returned data reflects the
        # persisted row.
        await self.session.refresh(token_model)
        return _to_data(token_model)

    async def find_by_token_hash(self, token_hash: str) -> RefreshTokenData | None:
        """Find a non-revoked refresh token by exact hash match.

        Args:
            token_hash: Bcrypt hash of the token.

        Returns:
            RefreshTokenData if found and not revoked, None otherwise.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one non-revoked
                row has this hash (raised by ``scalar_one_or_none``).
        """
        stmt = (
            select(RefreshToken)
            .where(RefreshToken.token_hash == token_hash)
            .where(RefreshToken.revoked_at.is_(None))
        )
        result = await self.session.execute(stmt)
        model = result.scalar_one_or_none()
        return _to_data(model) if model else None

    async def find_by_id(self, token_id: UUID) -> RefreshTokenData | None:
        """Find refresh token by ID.

        Unlike ``find_by_token_hash``, this lookup does not filter on
        ``revoked_at``, so revoked tokens are returned too.

        Args:
            token_id: Token's unique identifier.

        Returns:
            RefreshTokenData if found, None otherwise.
        """
        stmt = select(RefreshToken).where(RefreshToken.id == token_id)
        result = await self.session.execute(stmt)
        model = result.scalar_one_or_none()
        return _to_data(model) if model else None

    async def update_last_used(self, token_id: UUID) -> None:
        """Update last_used_at timestamp to the current UTC time and commit.

        Args:
            token_id: Token's unique identifier.

        Raises:
            sqlalchemy.exc.NoResultFound: If no token has this ID (raised by
                ``scalar_one``).
        """
        stmt = select(RefreshToken).where(RefreshToken.id == token_id)
        result = await self.session.execute(stmt)
        token = result.scalar_one()

        token.last_used_at = datetime.now(UTC)
        await self.session.commit()

    async def delete(self, token_id: UUID) -> None:
        """Delete refresh token (for rotation) and commit.

        Args:
            token_id: Token's unique identifier.

        Raises:
            sqlalchemy.exc.NoResultFound: If no token has this ID (raised by
                ``scalar_one``).
        """
        stmt = select(RefreshToken).where(RefreshToken.id == token_id)
        result = await self.session.execute(stmt)
        token = result.scalar_one()

        await self.session.delete(token)
        await self.session.commit()

    async def revoke_by_session(self, session_id: UUID) -> None:
        """Revoke all non-revoked refresh tokens for a session.

        Every affected token receives the same ``revoked_at`` timestamp and
        the reason ``"session_revoked"``. The change is committed before
        returning, even when no tokens matched.

        Args:
            session_id: Session ID to revoke tokens for.
        """
        stmt = (
            select(RefreshToken)
            .where(RefreshToken.session_id == session_id)
            .where(RefreshToken.revoked_at.is_(None))
        )
        result = await self.session.execute(stmt)
        tokens = result.scalars().all()

        # Take the clock reading once so that all tokens revoked by this call
        # share one timestamp and can be correlated in audits.
        revoked_at = datetime.now(UTC)
        for token in tokens:
            token.revoked_at = revoked_at
            token.revoked_reason = "session_revoked"

        await self.session.commit()

    async def revoke_all_for_user(
        self,
        user_id: UUID,
        reason: str = "user_requested",
    ) -> None:
        """Revoke all non-revoked refresh tokens for a user.

        Used when password changes or user logs out of all devices. Every
        affected token receives the same ``revoked_at`` timestamp. The change
        is committed before returning, even when no tokens matched.

        Args:
            user_id: User's unique identifier.
            reason: Reason for revocation (for audit).
        """
        stmt = (
            select(RefreshToken)
            .where(RefreshToken.user_id == user_id)
            .where(RefreshToken.revoked_at.is_(None))
        )
        result = await self.session.execute(stmt)
        tokens = result.scalars().all()

        # Take the clock reading once so that all tokens revoked by this call
        # share one timestamp and can be correlated in audits.
        revoked_at = datetime.now(UTC)
        for token in tokens:
            token.revoked_at = revoked_at
            token.revoked_reason = reason

        await self.session.commit()

    async def find_by_token_verification(
        self,
        token: str,
        verify_fn: Callable[[str, str], bool],
    ) -> RefreshTokenData | None:
        """Find refresh token by verifying against stored hashes.

        Since bcrypt hashes are non-deterministic, we iterate through active
        tokens and verify each one against the provided token. The first
        token for which ``verify_fn`` returns a truthy value wins.

        NOTE(review): the query loads every non-revoked token in the table
        (not scoped to a user or session) and does not filter on
        ``expires_at``; the cost grows with the number of active tokens.

        Args:
            token: Plain refresh token from user request.
            verify_fn: Function to verify token against hash (token, hash) -> bool.

        Returns:
            RefreshTokenData if found and verified, None otherwise.
        """
        # Get all active (non-revoked) tokens
        stmt = select(RefreshToken).where(RefreshToken.revoked_at.is_(None))
        result = await self.session.execute(stmt)
        tokens = result.scalars().all()

        # Verify each token against the provided token
        for token_model in tokens:
            if verify_fn(token, token_model.token_hash):
                return _to_data(token_model)

        return None
Functions
__init__
__init__(session: AsyncSession) -> None

Parameters:

Name Type Description Default
session AsyncSession

SQLAlchemy async session.

required
Source code in src/infrastructure/persistence/repositories/refresh_token_repository.py
def __init__(self, session: AsyncSession) -> None:
    """Bind the repository to the async session it will use for all queries.

    Args:
        session: SQLAlchemy async session; stored as ``self.session``.
    """
    self.session = session
save async
save(
    user_id: UUID,
    token_hash: str,
    session_id: UUID,
    expires_at: datetime,
    *,
    token_version: int = 1,
    global_version_at_issuance: int = 1
) -> RefreshTokenData

Create new refresh token in database.

Parameters:

Name Type Description Default
user_id UUID

User's unique identifier.

required
token_hash str

Bcrypt hash of the refresh token.

required
session_id UUID

Associated session ID.

required
expires_at datetime

Token expiration timestamp.

required
token_version int

Token version at issuance (for breach rotation).

1
global_version_at_issuance int

Global min version when issued.

1

Returns:

Type Description
RefreshTokenData

Created RefreshTokenData.

Source code in src/infrastructure/persistence/repositories/refresh_token_repository.py
async def save(
    self,
    user_id: UUID,
    token_hash: str,
    session_id: UUID,
    expires_at: datetime,
    *,
    token_version: int = 1,
    global_version_at_issuance: int = 1,
) -> RefreshTokenData:
    """Persist a newly issued refresh token and return its stored form.

    Args:
        user_id: User's unique identifier.
        token_hash: Bcrypt hash of the refresh token.
        session_id: Associated session ID.
        expires_at: Token expiration timestamp.
        token_version: Token version at issuance (for breach rotation).
        global_version_at_issuance: Global min version when issued.

    Returns:
        RefreshTokenData built from the committed and refreshed row.
    """
    db = self.session
    columns = dict(
        user_id=user_id,
        token_hash=token_hash,
        session_id=session_id,
        expires_at=expires_at,
        token_version=token_version,
        global_version_at_issuance=global_version_at_issuance,
    )
    record = RefreshToken(**columns)

    db.add(record)
    await db.commit()
    # Reload after commit so the converted result reflects the persisted row.
    await db.refresh(record)

    return _to_data(record)
find_by_token_hash async
find_by_token_hash(
    token_hash: str,
) -> RefreshTokenData | None

Find refresh token by hash.

Parameters:

Name Type Description Default
token_hash str

Bcrypt hash of the token.

required

Returns:

Type Description
RefreshTokenData | None

RefreshTokenData if found and not revoked, None otherwise.

Source code in src/infrastructure/persistence/repositories/refresh_token_repository.py
async def find_by_token_hash(self, token_hash: str) -> RefreshTokenData | None:
    """Look up a non-revoked refresh token by exact hash match.

    Args:
        token_hash: Bcrypt hash of the token.

    Returns:
        RefreshTokenData if found and not revoked, None otherwise.
    """
    matches_hash = RefreshToken.token_hash == token_hash
    not_revoked = RefreshToken.revoked_at.is_(None)
    # Multiple criteria passed to a single where() are joined with AND.
    query = select(RefreshToken).where(matches_hash, not_revoked)

    row = (await self.session.execute(query)).scalar_one_or_none()
    if not row:
        return None
    return _to_data(row)
find_by_id async
find_by_id(token_id: UUID) -> RefreshTokenData | None

Find refresh token by ID.

Parameters:

Name Type Description Default
token_id UUID

Token's unique identifier.

required

Returns:

Type Description
RefreshTokenData | None

RefreshTokenData if found, None otherwise.

Source code in src/infrastructure/persistence/repositories/refresh_token_repository.py
async def find_by_id(self, token_id: UUID) -> RefreshTokenData | None:
    """Fetch a refresh token by primary identifier.

    No ``revoked_at`` filter is applied, so a revoked token is returned as
    well.

    Args:
        token_id: Token's unique identifier.

    Returns:
        RefreshTokenData if found, None otherwise.
    """
    query = select(RefreshToken).where(RefreshToken.id == token_id)
    row = (await self.session.execute(query)).scalar_one_or_none()
    if not row:
        return None
    return _to_data(row)
update_last_used async
update_last_used(token_id: UUID) -> None

Update last_used_at timestamp.

Parameters:

Name Type Description Default
token_id UUID

Token's unique identifier.

required
Source code in src/infrastructure/persistence/repositories/refresh_token_repository.py
async def update_last_used(self, token_id: UUID) -> None:
    """Stamp the token's ``last_used_at`` with the current UTC time and commit.

    Args:
        token_id: Token's unique identifier.
    """
    db = self.session
    query = select(RefreshToken).where(RefreshToken.id == token_id)
    # scalar_one() requires exactly one matching row and raises otherwise.
    record = (await db.execute(query)).scalar_one()
    record.last_used_at = datetime.now(UTC)
    await db.commit()
delete async
delete(token_id: UUID) -> None

Delete refresh token (for rotation).

Parameters:

Name Type Description Default
token_id UUID

Token's unique identifier.

required
Source code in src/infrastructure/persistence/repositories/refresh_token_repository.py
async def delete(self, token_id: UUID) -> None:
    """Remove a refresh token row (used during rotation) and commit.

    Args:
        token_id: Token's unique identifier.
    """
    db = self.session
    query = select(RefreshToken).where(RefreshToken.id == token_id)
    # scalar_one() requires exactly one matching row and raises otherwise.
    record = (await db.execute(query)).scalar_one()
    await db.delete(record)
    await db.commit()
revoke_by_session async
revoke_by_session(session_id: UUID) -> None

Revoke all refresh tokens for a session.

Parameters:

Name Type Description Default
session_id UUID

Session ID to revoke tokens for.

required
Source code in src/infrastructure/persistence/repositories/refresh_token_repository.py
async def revoke_by_session(self, session_id: UUID) -> None:
    """Revoke all non-revoked refresh tokens for a session.

    Every affected token receives the same ``revoked_at`` timestamp and the
    reason ``"session_revoked"``. The change is committed before returning,
    even when no tokens matched.

    Args:
        session_id: Session ID to revoke tokens for.
    """
    stmt = (
        select(RefreshToken)
        .where(RefreshToken.session_id == session_id)
        .where(RefreshToken.revoked_at.is_(None))
    )
    result = await self.session.execute(stmt)
    tokens = result.scalars().all()

    # Take the clock reading once so that all tokens revoked by this call
    # share one timestamp and can be correlated in audits.
    revoked_at = datetime.now(UTC)
    for token in tokens:
        token.revoked_at = revoked_at
        token.revoked_reason = "session_revoked"

    await self.session.commit()
revoke_all_for_user async
revoke_all_for_user(
    user_id: UUID, reason: str = "user_requested"
) -> None

Revoke all refresh tokens for a user.

Used when password changes or user logs out of all devices.

Parameters:

Name Type Description Default
user_id UUID

User's unique identifier.

required
reason str

Reason for revocation (for audit).

'user_requested'
Source code in src/infrastructure/persistence/repositories/refresh_token_repository.py
async def revoke_all_for_user(
    self,
    user_id: UUID,
    reason: str = "user_requested",
) -> None:
    """Revoke all non-revoked refresh tokens for a user.

    Used when password changes or user logs out of all devices. Every
    affected token receives the same ``revoked_at`` timestamp. The change is
    committed before returning, even when no tokens matched.

    Args:
        user_id: User's unique identifier.
        reason: Reason for revocation (for audit).
    """
    stmt = (
        select(RefreshToken)
        .where(RefreshToken.user_id == user_id)
        .where(RefreshToken.revoked_at.is_(None))
    )
    result = await self.session.execute(stmt)
    tokens = result.scalars().all()

    # Take the clock reading once so that all tokens revoked by this call
    # share one timestamp and can be correlated in audits.
    revoked_at = datetime.now(UTC)
    for token in tokens:
        token.revoked_at = revoked_at
        token.revoked_reason = reason

    await self.session.commit()
find_by_token_verification async
find_by_token_verification(
    token: str, verify_fn: Callable[[str, str], bool]
) -> RefreshTokenData | None

Find refresh token by verifying against stored hashes.

Since bcrypt hashes are non-deterministic, we iterate through active tokens and verify each one against the provided token.

Parameters:

Name Type Description Default
token str

Plain refresh token from user request.

required
verify_fn Callable[[str, str], bool]

Function to verify token against hash (token, hash) -> bool.

required

Returns:

Type Description
RefreshTokenData | None

RefreshTokenData if found and verified, None otherwise.

Source code in src/infrastructure/persistence/repositories/refresh_token_repository.py
async def find_by_token_verification(
    self,
    token: str,
    verify_fn: Callable[[str, str], bool],
) -> RefreshTokenData | None:
    """Locate the stored token whose hash verifies against a plain token.

    Bcrypt hashes are non-deterministic, so an equality lookup is not
    possible; each non-revoked token's hash is checked with ``verify_fn``
    in turn and the first match is returned.

    NOTE(review): the query loads every non-revoked token in the table (not
    scoped to a user or session) and does not filter on ``expires_at``; the
    cost grows with the number of active tokens.

    Args:
        token: Plain refresh token from user request.
        verify_fn: Function to verify token against hash (token, hash) -> bool.

    Returns:
        RefreshTokenData if found and verified, None otherwise.
    """
    query = select(RefreshToken).where(RefreshToken.revoked_at.is_(None))
    candidates = (await self.session.execute(query)).scalars().all()

    # The generator is lazy, so verification stops at the first match.
    matched = next(
        (row for row in candidates if verify_fn(token, row.token_hash)),
        None,
    )
    return None if matched is None else _to_data(matched)