Skip to content

infrastructure.cache.session_cache

src.infrastructure.cache.session_cache

Redis implementation of SessionCache protocol.

Provides fast (<5ms) session lookups via Redis caching. Uses write-through caching: writes go to both cache and database.

Key Patterns
  • session:{session_id} -> JSON serialized SessionData
  • user:{user_id}:sessions -> JSON array of session ID strings (stored as a plain string value, not a native Redis Set)
Architecture
  • Implements SessionCache protocol (structural typing)
  • Uses RedisAdapter for low-level operations
  • Returns None on cache miss (fail-open for resilience)
  • Database is always source of truth
Reference
  • docs/architecture/session-management-architecture.md

Classes

RedisSessionCache

Redis implementation of SessionCache protocol.

Provides fast session lookups and maintains user->sessions index for efficient bulk operations.

Note: Does NOT inherit from SessionCache protocol (uses structural typing).

Key Patterns
  • session:{session_id} -> Full session data (JSON)
  • user:{user_id}:sessions -> JSON array of session ID strings

Attributes:

Name Type Description
_cache

Cache instance implementing CacheProtocol.

Source code in src/infrastructure/cache/session_cache.py
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
class RedisSessionCache:
    """Redis implementation of SessionCache protocol.

    Provides fast session lookups and maintains user->sessions index
    for efficient bulk operations.

    Note: Does NOT inherit from SessionCache protocol (uses structural typing).

    Key Patterns:
        - session:{session_id} -> Full session data (JSON)
        - user:{user_id}:sessions -> Set of session IDs

    Attributes:
        _cache: Cache instance implementing CacheProtocol.
    """

    def __init__(self, cache: CacheProtocol) -> None:
        """Initialize session cache.

        Args:
            cache: Cache instance implementing CacheProtocol.
        """
        self._cache = cache

    def _session_key(self, session_id: UUID) -> str:
        """Generate cache key for session data.

        Args:
            session_id: Session identifier.

        Returns:
            Cache key string.
        """
        return f"session:{session_id}"

    def _user_sessions_key(self, user_id: UUID) -> str:
        """Generate cache key for user's session set.

        Args:
            user_id: User identifier.

        Returns:
            Cache key string.
        """
        return f"user:{user_id}:sessions"

    async def get(self, session_id: UUID) -> SessionData | None:
        """Get session data from cache.

        Args:
            session_id: Session identifier.

        Returns:
            SessionData if cached, None otherwise (cache miss or error).
        """
        key = self._session_key(session_id)
        result = await self._cache.get_json(key)

        match result:
            case Success(value=None):
                return None
            case Success(value=data) if data is not None:
                try:
                    return self._from_dict(data)
                except (KeyError, TypeError, ValueError) as e:
                    logger.warning(
                        "Failed to deserialize session from cache",
                        extra={"session_id": str(session_id), "error": str(e)},
                    )
                    return None
            case _:
                # Cache error - fail open (return None)
                logger.warning(
                    "Cache error getting session",
                    extra={"session_id": str(session_id)},
                )
                return None

    async def set(
        self,
        session_data: SessionData,
        *,
        ttl_seconds: int | None = None,
    ) -> None:
        """Store session data in cache.

        Also maintains the user->sessions index.

        Args:
            session_data: Session data to cache.
            ttl_seconds: Cache TTL in seconds. If None, calculates from
                session expires_at. Defaults to 30 days if no expiry.
        """
        # Calculate TTL
        if ttl_seconds is None:
            if session_data.expires_at is not None:
                now = datetime.now(UTC)
                # Handle timezone-naive expires_at
                expires_at = session_data.expires_at
                if expires_at.tzinfo is None:
                    expires_at = expires_at.replace(tzinfo=UTC)
                ttl_seconds = max(1, int((expires_at - now).total_seconds()))
            else:
                ttl_seconds = DEFAULT_SESSION_TTL

        # Store session data
        key = self._session_key(session_data.id)
        data = self._to_dict(session_data)

        result = await self._cache.set_json(key, data, ttl=ttl_seconds)
        if not isinstance(result, Success):
            logger.warning(
                "Failed to cache session",
                extra={"session_id": str(session_data.id)},
            )
            return

        # Add to user's session index
        await self.add_user_session(session_data.user_id, session_data.id)

    async def delete(self, session_id: UUID) -> bool:
        """Remove session from cache.

        Note: Does NOT remove from user index (caller should use remove_user_session).

        Args:
            session_id: Session identifier.

        Returns:
            True if deleted, False if not found or error.
        """
        key = self._session_key(session_id)
        result = await self._cache.delete(key)

        match result:
            case Success(value=deleted):
                return deleted
            case _:
                logger.warning(
                    "Cache error deleting session",
                    extra={"session_id": str(session_id)},
                )
                return False

    async def delete_all_for_user(self, user_id: UUID) -> int:
        """Remove all sessions for a user from cache.

        Removes session data and clears user's session index.

        Args:
            user_id: User identifier.

        Returns:
            Number of sessions removed from cache.
        """
        # Get all session IDs for user
        session_ids = await self.get_user_session_ids(user_id)

        if not session_ids:
            return 0

        # Delete each session
        deleted_count = 0
        for session_id in session_ids:
            if await self.delete(session_id):
                deleted_count += 1

        # Clear the user's session index
        user_key = self._user_sessions_key(user_id)
        await self._cache.delete(user_key)

        return deleted_count

    async def exists(self, session_id: UUID) -> bool:
        """Check if session exists in cache (quick validation).

        Args:
            session_id: Session identifier.

        Returns:
            True if session exists in cache, False otherwise.
        """
        key = self._session_key(session_id)
        result = await self._cache.exists(key)

        match result:
            case Success(value=exists):
                return exists
            case _:
                return False

    async def get_user_session_ids(self, user_id: UUID) -> list[UUID]:
        """Get all session IDs for a user from cache.

        Args:
            user_id: User identifier.

        Returns:
            List of session IDs, empty if none cached or error.
        """
        key = self._user_sessions_key(user_id)
        result = await self._cache.get(key)

        match result:
            case Success(value=None):
                return []
            case Success(value=data) if data is not None:
                try:
                    # Stored as JSON array of UUID strings
                    ids = json.loads(data)
                    return [UUID(id_str) for id_str in ids]
                except (json.JSONDecodeError, ValueError) as e:
                    logger.warning(
                        "Failed to parse user session IDs from cache",
                        extra={"user_id": str(user_id), "error": str(e)},
                    )
                    return []
            case _:
                return []

    async def add_user_session(self, user_id: UUID, session_id: UUID) -> None:
        """Add session ID to user's session set.

        Args:
            user_id: User identifier.
            session_id: Session identifier.
        """
        key = self._user_sessions_key(user_id)

        # Get current list
        current_ids = await self.get_user_session_ids(user_id)

        # Add new session if not already present
        if session_id not in current_ids:
            current_ids.append(session_id)

        # Store updated list
        data = json.dumps([str(sid) for sid in current_ids])
        await self._cache.set(key, data, ttl=DEFAULT_SESSION_TTL)

    async def remove_user_session(self, user_id: UUID, session_id: UUID) -> None:
        """Remove session ID from user's session set.

        Args:
            user_id: User identifier.
            session_id: Session identifier.
        """
        key = self._user_sessions_key(user_id)

        # Get current list
        current_ids = await self.get_user_session_ids(user_id)

        # Remove session if present
        if session_id in current_ids:
            current_ids.remove(session_id)

        if current_ids:
            # Store updated list
            data = json.dumps([str(sid) for sid in current_ids])
            await self._cache.set(key, data, ttl=DEFAULT_SESSION_TTL)
        else:
            # No sessions left, delete the key
            await self._cache.delete(key)

    async def update_last_activity(
        self,
        session_id: UUID,
        ip_address: str | None = None,
    ) -> bool:
        """Update session's last activity in cache.

        Lightweight update - only modifies last_activity_at and optionally last_ip_address.

        Args:
            session_id: Session identifier.
            ip_address: Current IP address (optional).

        Returns:
            True if updated, False if session not in cache.
        """
        # Get current session data
        session_data = await self.get(session_id)
        if session_data is None:
            return False

        # Create updated copy
        now = datetime.now(UTC)
        updated_data = SessionData(
            id=session_data.id,
            user_id=session_data.user_id,
            device_info=session_data.device_info,
            user_agent=session_data.user_agent,
            ip_address=session_data.ip_address,
            location=session_data.location,
            created_at=session_data.created_at,
            last_activity_at=now,
            expires_at=session_data.expires_at,
            is_revoked=session_data.is_revoked,
            is_trusted=session_data.is_trusted,
            revoked_at=session_data.revoked_at,
            revoked_reason=session_data.revoked_reason,
            refresh_token_id=session_data.refresh_token_id,
            last_ip_address=ip_address if ip_address else session_data.last_ip_address,
            suspicious_activity_count=session_data.suspicious_activity_count,
            last_provider_accessed=session_data.last_provider_accessed,
            last_provider_sync_at=session_data.last_provider_sync_at,
            providers_accessed=session_data.providers_accessed,
        )

        # Store updated data (preserves existing TTL by calculating from expires_at)
        await self.set(updated_data)
        return True

    # =========================================================================
    # Serialization helpers
    # =========================================================================

    def _to_dict(self, session_data: SessionData) -> dict[str, object]:
        """Convert SessionData to dict for JSON serialization.

        Args:
            session_data: SessionData to convert.

        Returns:
            Dictionary representation.
        """
        data = asdict(session_data)

        # Convert UUID to string
        data["id"] = str(session_data.id)
        data["user_id"] = str(session_data.user_id)

        if session_data.refresh_token_id:
            data["refresh_token_id"] = str(session_data.refresh_token_id)

        # Convert datetime to ISO string
        for dt_field in [
            "created_at",
            "last_activity_at",
            "expires_at",
            "revoked_at",
            "last_provider_sync_at",
        ]:
            if data.get(dt_field) is not None:
                data[dt_field] = data[dt_field].isoformat()

        return data

    def _from_dict(self, data: dict[str, object]) -> SessionData:
        """Convert dict to SessionData.

        Args:
            data: Dictionary from cache.

        Returns:
            SessionData instance.

        Raises:
            KeyError: If required field missing.
            ValueError: If UUID or datetime parsing fails.
        """
        # Parse UUIDs - cast to str for UUID constructor
        session_id = UUID(str(data["id"]))
        user_id = UUID(str(data["user_id"]))
        refresh_token_id_raw = data.get("refresh_token_id")
        refresh_token_id = (
            UUID(str(refresh_token_id_raw)) if refresh_token_id_raw else None
        )

        # Parse datetimes with explicit string cast
        def parse_dt(val: object | None) -> datetime | None:
            if val is None:
                return None
            return datetime.fromisoformat(str(val))

        # Extract values with explicit type casting
        device_info = str(data["device_info"]) if data.get("device_info") else None
        user_agent = str(data["user_agent"]) if data.get("user_agent") else None
        ip_address = str(data["ip_address"]) if data.get("ip_address") else None
        location = str(data["location"]) if data.get("location") else None
        revoked_reason = (
            str(data["revoked_reason"]) if data.get("revoked_reason") else None
        )
        last_ip = str(data["last_ip_address"]) if data.get("last_ip_address") else None
        last_provider = (
            str(data["last_provider_accessed"])
            if data.get("last_provider_accessed")
            else None
        )

        # Extract boolean and int with defaults
        is_revoked = bool(data.get("is_revoked", False))
        is_trusted = bool(data.get("is_trusted", False))
        suspicious_raw = data.get("suspicious_activity_count", 0)
        # Value is int from JSON - safe to cast
        suspicious_count = int(suspicious_raw) if suspicious_raw else 0  # type: ignore[call-overload]

        # Extract list with proper typing
        providers_raw = data.get("providers_accessed")
        providers_accessed: list[str] | None = (
            list(providers_raw) if isinstance(providers_raw, list) else None
        )

        return SessionData(
            id=session_id,
            user_id=user_id,
            device_info=device_info,
            user_agent=user_agent,
            ip_address=ip_address,
            location=location,
            created_at=parse_dt(data.get("created_at")),
            last_activity_at=parse_dt(data.get("last_activity_at")),
            expires_at=parse_dt(data.get("expires_at")),
            is_revoked=is_revoked,
            is_trusted=is_trusted,
            revoked_at=parse_dt(data.get("revoked_at")),
            revoked_reason=revoked_reason,
            refresh_token_id=refresh_token_id,
            last_ip_address=last_ip,
            suspicious_activity_count=suspicious_count,
            last_provider_accessed=last_provider,
            last_provider_sync_at=parse_dt(data.get("last_provider_sync_at")),
            providers_accessed=providers_accessed,
        )
Functions
__init__
__init__(cache: CacheProtocol) -> None

Parameters:

Name Type Description Default
cache CacheProtocol

Cache instance implementing CacheProtocol.

required
Source code in src/infrastructure/cache/session_cache.py
def __init__(self, cache: CacheProtocol) -> None:
    """Initialize session cache.

    Stores the given cache object on the instance as ``_cache``; no cache
    call is made here.

    Args:
        cache: Cache instance implementing CacheProtocol.
    """
    self._cache = cache
get async
get(session_id: UUID) -> SessionData | None

Get session data from cache.

Parameters:

Name Type Description Default
session_id UUID

Session identifier.

required

Returns:

Type Description
SessionData | None

SessionData if cached, None otherwise (cache miss or error).

Source code in src/infrastructure/cache/session_cache.py
async def get(self, session_id: UUID) -> SessionData | None:
    """Get session data from cache.

    Args:
        session_id: Session identifier.

    Returns:
        SessionData if cached, None otherwise (cache miss or error).
    """
    key = self._session_key(session_id)
    result = await self._cache.get_json(key)

    match result:
        case Success(value=None):
            return None
        case Success(value=data) if data is not None:
            try:
                return self._from_dict(data)
            except (KeyError, TypeError, ValueError) as e:
                logger.warning(
                    "Failed to deserialize session from cache",
                    extra={"session_id": str(session_id), "error": str(e)},
                )
                return None
        case _:
            # Cache error - fail open (return None)
            logger.warning(
                "Cache error getting session",
                extra={"session_id": str(session_id)},
            )
            return None
set async
set(
    session_data: SessionData,
    *,
    ttl_seconds: int | None = None
) -> None

Store session data in cache.

Also maintains the user->sessions index.

Parameters:

Name Type Description Default
session_data SessionData

Session data to cache.

required
ttl_seconds int | None

Cache TTL in seconds. If None, calculates from session expires_at. Defaults to 30 days if no expiry.

None
Source code in src/infrastructure/cache/session_cache.py
async def set(
    self,
    session_data: SessionData,
    *,
    ttl_seconds: int | None = None,
) -> None:
    """Write a session to the cache and register it in the user's index.

    Args:
        session_data: Session data to cache.
        ttl_seconds: Cache TTL in seconds. When None, the TTL is derived
            from the session's expires_at (never less than 1 second), or
            DEFAULT_SESSION_TTL is used if the session has no expiry.
    """
    ttl = ttl_seconds
    if ttl is None:
        expiry = session_data.expires_at
        if expiry is None:
            ttl = DEFAULT_SESSION_TTL
        else:
            # A naive expiry is interpreted as UTC before subtracting.
            if expiry.tzinfo is None:
                expiry = expiry.replace(tzinfo=UTC)
            remaining = (expiry - datetime.now(UTC)).total_seconds()
            ttl = max(1, int(remaining))

    cache_key = self._session_key(session_data.id)
    payload = self._to_dict(session_data)
    outcome = await self._cache.set_json(cache_key, payload, ttl=ttl)

    if isinstance(outcome, Success):
        # Only index the session once its entry was stored.
        await self.add_user_session(session_data.user_id, session_data.id)
        return

    logger.warning(
        "Failed to cache session",
        extra={"session_id": str(session_data.id)},
    )
delete async
delete(session_id: UUID) -> bool

Remove session from cache.

Note: Does NOT remove from user index (caller should use remove_user_session).

Parameters:

Name Type Description Default
session_id UUID

Session identifier.

required

Returns:

Type Description
bool

True if deleted, False if not found or error.

Source code in src/infrastructure/cache/session_cache.py
async def delete(self, session_id: UUID) -> bool:
    """Remove session from cache.

    Note: Does NOT remove from user index (caller should use remove_user_session).

    Args:
        session_id: Session identifier.

    Returns:
        True if deleted, False if not found or error.
    """
    key = self._session_key(session_id)
    result = await self._cache.delete(key)

    match result:
        case Success(value=deleted):
            return deleted
        case _:
            logger.warning(
                "Cache error deleting session",
                extra={"session_id": str(session_id)},
            )
            return False
delete_all_for_user async
delete_all_for_user(user_id: UUID) -> int

Remove all sessions for a user from cache.

Removes session data and clears user's session index.

Parameters:

Name Type Description Default
user_id UUID

User identifier.

required

Returns:

Type Description
int

Number of sessions removed from cache.

Source code in src/infrastructure/cache/session_cache.py
async def delete_all_for_user(self, user_id: UUID) -> int:
    """Evict every cached session listed in a user's index.

    Each listed session entry is deleted one after another, then the
    index key itself is deleted.

    Args:
        user_id: User identifier.

    Returns:
        How many per-session deletes reported success; 0 when the index
        yields no IDs (in which case the index key is not touched).
    """
    indexed_ids = await self.get_user_session_ids(user_id)
    if not indexed_ids:
        return 0

    # Sequential awaits, in index order.
    outcomes = [await self.delete(sid) for sid in indexed_ids]
    removed = sum(1 for ok in outcomes if ok)

    await self._cache.delete(self._user_sessions_key(user_id))
    return removed
exists async
exists(session_id: UUID) -> bool

Check if session exists in cache (quick validation).

Parameters:

Name Type Description Default
session_id UUID

Session identifier.

required

Returns:

Type Description
bool

True if session exists in cache, False otherwise.

Source code in src/infrastructure/cache/session_cache.py
async def exists(self, session_id: UUID) -> bool:
    """Check if session exists in cache (quick validation).

    Args:
        session_id: Session identifier.

    Returns:
        True if session exists in cache, False otherwise.
    """
    key = self._session_key(session_id)
    result = await self._cache.exists(key)

    match result:
        case Success(value=exists):
            return exists
        case _:
            return False
get_user_session_ids async
get_user_session_ids(user_id: UUID) -> list[UUID]

Get all session IDs for a user from cache.

Parameters:

Name Type Description Default
user_id UUID

User identifier.

required

Returns:

Type Description
list[UUID]

List of session IDs, empty if none cached or error.

Source code in src/infrastructure/cache/session_cache.py
async def get_user_session_ids(self, user_id: UUID) -> list[UUID]:
    """Get all session IDs for a user from cache.

    Args:
        user_id: User identifier.

    Returns:
        List of session IDs, empty if none cached or error.
    """
    key = self._user_sessions_key(user_id)
    result = await self._cache.get(key)

    match result:
        case Success(value=None):
            return []
        case Success(value=data) if data is not None:
            try:
                # Stored as JSON array of UUID strings
                ids = json.loads(data)
                return [UUID(id_str) for id_str in ids]
            except (json.JSONDecodeError, ValueError) as e:
                logger.warning(
                    "Failed to parse user session IDs from cache",
                    extra={"user_id": str(user_id), "error": str(e)},
                )
                return []
        case _:
            return []
add_user_session async
add_user_session(user_id: UUID, session_id: UUID) -> None

Add session ID to user's session set.

Parameters:

Name Type Description Default
user_id UUID

User identifier.

required
session_id UUID

Session identifier.

required
Source code in src/infrastructure/cache/session_cache.py
async def add_user_session(self, user_id: UUID, session_id: UUID) -> None:
    """Record a session ID in the user's session index.

    The index is read, the ID appended if missing, and the whole array
    written back with DEFAULT_SESSION_TTL. The write happens even when
    the ID was already listed.

    Args:
        user_id: User identifier.
        session_id: Session identifier.
    """
    index_key = self._user_sessions_key(user_id)
    known_ids = await self.get_user_session_ids(user_id)

    if session_id not in known_ids:
        known_ids.append(session_id)

    # Serialized as a JSON array of UUID strings.
    payload = json.dumps(list(map(str, known_ids)))
    await self._cache.set(index_key, payload, ttl=DEFAULT_SESSION_TTL)
remove_user_session async
remove_user_session(
    user_id: UUID, session_id: UUID
) -> None

Remove session ID from user's session set.

Parameters:

Name Type Description Default
user_id UUID

User identifier.

required
session_id UUID

Session identifier.

required
Source code in src/infrastructure/cache/session_cache.py
async def remove_user_session(self, user_id: UUID, session_id: UUID) -> None:
    """Drop a session ID from the user's session index.

    When no IDs remain afterwards the index key is deleted; otherwise the
    shortened array is written back with DEFAULT_SESSION_TTL.

    Args:
        user_id: User identifier.
        session_id: Session identifier.
    """
    index_key = self._user_sessions_key(user_id)
    remaining = await self.get_user_session_ids(user_id)

    try:
        remaining.remove(session_id)
    except ValueError:
        # The ID was not listed; nothing to take out.
        pass

    if not remaining:
        # Empty index: remove the key instead of storing "[]".
        await self._cache.delete(index_key)
        return

    payload = json.dumps(list(map(str, remaining)))
    await self._cache.set(index_key, payload, ttl=DEFAULT_SESSION_TTL)
update_last_activity async
update_last_activity(
    session_id: UUID, ip_address: str | None = None
) -> bool

Update session's last activity in cache.

Lightweight update - only modifies last_activity_at and optionally last_ip_address.

Parameters:

Name Type Description Default
session_id UUID

Session identifier.

required
ip_address str | None

Current IP address (optional).

None

Returns:

Type Description
bool

True if updated, False if session not in cache.

Source code in src/infrastructure/cache/session_cache.py
async def update_last_activity(
    self,
    session_id: UUID,
    ip_address: str | None = None,
) -> bool:
    """Update session's last activity in cache.

    Lightweight update - only modifies last_activity_at and optionally last_ip_address.

    Args:
        session_id: Session identifier.
        ip_address: Current IP address (optional).

    Returns:
        True if updated, False if session not in cache.
    """
    # Get current session data
    session_data = await self.get(session_id)
    if session_data is None:
        return False

    # Create updated copy
    now = datetime.now(UTC)
    updated_data = SessionData(
        id=session_data.id,
        user_id=session_data.user_id,
        device_info=session_data.device_info,
        user_agent=session_data.user_agent,
        ip_address=session_data.ip_address,
        location=session_data.location,
        created_at=session_data.created_at,
        last_activity_at=now,
        expires_at=session_data.expires_at,
        is_revoked=session_data.is_revoked,
        is_trusted=session_data.is_trusted,
        revoked_at=session_data.revoked_at,
        revoked_reason=session_data.revoked_reason,
        refresh_token_id=session_data.refresh_token_id,
        last_ip_address=ip_address if ip_address else session_data.last_ip_address,
        suspicious_activity_count=session_data.suspicious_activity_count,
        last_provider_accessed=session_data.last_provider_accessed,
        last_provider_sync_at=session_data.last_provider_sync_at,
        providers_accessed=session_data.providers_accessed,
    )

    # Store updated data (preserves existing TTL by calculating from expires_at)
    await self.set(updated_data)
    return True