Skip to content

infrastructure.rate_limit.redis_storage

src.infrastructure.rate_limit.redis_storage

Redis-backed storage for rate limiting using atomic Lua scripts.

This module implements the low-level token bucket operations against Redis using an atomic Lua script (EVALSHA). It is intentionally focused on storage concerns (key shaping is done by the higher-level adapter).

Fail-open policy

Check/read operations (check_and_consume, get_remaining) return Success on infrastructure failures, with conservative defaults that ALLOW requests. Actual system errors on administrative operations (e.g., admin reset failure) are returned as Failure(RateLimitError).

Note

This is a storage component used by the higher-level adapter that implements RateLimitProtocol. It is not exposed directly to the presentation layer.

Classes

RedisStorage

Redis storage for rate limiting with atomic Lua script.

This class loads the token bucket Lua script once and executes it via EVALSHA for atomic check/consume operations.

Parameters:

Name Type Description Default
redis_client Any

An async Redis client (redis.asyncio.Redis compatible).

required

Attributes:

Name Type Description
redis

The Redis client instance.

_lua

Cached Lua script SHAs.

Source code in src/infrastructure/rate_limit/redis_storage.py
class RedisStorage:
    """Redis storage for rate limiting with atomic Lua script.

    This class loads the token bucket Lua script once and executes it via
    EVALSHA for atomic check/consume operations.

    Args:
        redis_client: An async Redis client (redis.asyncio.Redis compatible).

    Attributes:
        redis: The Redis client instance.
        _lua: Cached Lua script SHAs.

    """

    def __init__(self, *, redis_client: Any) -> None:
        self.redis = redis_client
        self._lua = _LuaRefs()
        self._script_lock = asyncio.Lock()

    # ---------------------------------------------------------------------
    # Public API
    # ---------------------------------------------------------------------
    async def check_and_consume(
        self,
        *,
        key_base: str,
        rule: RateLimitRule,
        cost: int = 1,
        now_ts: float | None = None,
    ) -> Result[tuple[bool, float, int], RateLimitError]:
        """Atomically check and optionally consume tokens.

        Uses Lua script to ensure check-and-consume is a single atomic operation.

        Args:
            key_base: Base key for bucket (no suffix; storage adds ':tokens'/' :time').
            rule: Rate limit rule containing capacity/refill.
            cost: Tokens to consume for this request.
            now_ts: Override current timestamp in seconds (for testing). Defaults to time().

        Returns:
            Result with tuple: (allowed, retry_after_seconds, remaining_tokens)

        Fail-open:
            On Redis errors, returns Success(True, 0.0, rule.max_tokens).
        """
        try:
            sha = await self._ensure_token_bucket_script()
            now = now_ts if now_ts is not None else time()
            resp = await self.redis.evalsha(
                sha,
                1,
                key_base,
                int(rule.max_tokens),
                float(rule.refill_rate),
                int(max(0, cost)),
                float(now),
            )
            # Expect resp: [allowed(0/1), retry_after(float), remaining(int)]
            allowed = bool(resp[0])
            retry_after = float(resp[1])
            remaining = int(resp[2])
            return Success(value=(allowed, retry_after, remaining))
        except Exception:  # Fail-open
            return Success(value=(True, 0.0, rule.max_tokens))

    async def get_remaining(
        self,
        *,
        key_base: str,
        rule: RateLimitRule,
        now_ts: float | None = None,
    ) -> Result[int, RateLimitError]:
        """Get remaining tokens without consuming any.

        Implementation detail: calls Lua with cost=0 to avoid consumption.

        Fail-open:
            On Redis errors, returns Success(rule.max_tokens).
        """
        result = await self.check_and_consume(
            key_base=key_base, rule=rule, cost=0, now_ts=now_ts
        )
        match result:
            case Success(value=(_, _, remaining)):
                return Success(value=remaining)
            case _:
                # Fail-open: return max tokens on any error
                return Success(value=rule.max_tokens)

    async def reset(
        self,
        *,
        key_base: str,
        rule: RateLimitRule,
        now_ts: float | None = None,
    ) -> Result[None, RateLimitError]:
        """Reset the bucket to full capacity.

        Unlike check operations, reset should report real errors to callers.
        """
        try:
            now = now_ts if now_ts is not None else time()
            ttl = rule.ttl_seconds
            pipe = self.redis.pipeline(transaction=True)
            pipe.setex(f"{key_base}:tokens", ttl, int(rule.max_tokens))
            pipe.setex(f"{key_base}:time", ttl, float(now))
            await pipe.execute()
            return Success(value=None)
        except Exception as exc:
            return Failure(
                error=RateLimitError(
                    code=ErrorCode.RATE_LIMIT_RESET_FAILED,
                    message=f"Failed to reset rate limit for '{key_base}': {exc}",
                    details={"key_base": key_base},
                )
            )

    # ---------------------------------------------------------------------
    # Internal helpers
    # ---------------------------------------------------------------------
    async def _ensure_token_bucket_script(self) -> str:
        """Load token bucket Lua script into Redis and cache the SHA.

        Returns:
            str: Script SHA.
        """
        if self._lua.token_bucket_sha:
            return self._lua.token_bucket_sha
        async with self._script_lock:
            if self._lua.token_bucket_sha:
                return self._lua.token_bucket_sha
            script = await _read_lua_script("lua_scripts/token_bucket.lua")
            sha: str = await self.redis.script_load(script)
            self._lua.token_bucket_sha = sha
            return sha
Functions
check_and_consume async
check_and_consume(
    *,
    key_base: str,
    rule: RateLimitRule,
    cost: int = 1,
    now_ts: float | None = None
) -> Result[tuple[bool, float, int], RateLimitError]

Atomically check and optionally consume tokens.

Uses Lua script to ensure check-and-consume is a single atomic operation.

Parameters:

Name Type Description Default
key_base str

Base key for bucket (no suffix; storage adds ':tokens'/':time').

required
rule RateLimitRule

Rate limit rule containing capacity/refill.

required
cost int

Tokens to consume for this request.

1
now_ts float | None

Override current timestamp in seconds (for testing). Defaults to time().

None

Returns:

Type Description
Result[tuple[bool, float, int], RateLimitError]

Result with tuple: (allowed, retry_after_seconds, remaining_tokens)

Fail-open

On Redis errors, returns Success(True, 0.0, rule.max_tokens).

Source code in src/infrastructure/rate_limit/redis_storage.py
async def check_and_consume(
    self,
    *,
    key_base: str,
    rule: RateLimitRule,
    cost: int = 1,
    now_ts: float | None = None,
) -> Result[tuple[bool, float, int], RateLimitError]:
    """Atomically check and optionally consume tokens.

    The Lua script makes the check-and-consume a single atomic operation on
    the Redis side.

    Args:
        key_base: Base key for the bucket (no suffix; storage appends
            ':tokens'/':time').
        rule: Rate limit rule containing capacity/refill.
        cost: Tokens to consume for this request.
        now_ts: Override current timestamp in seconds (for testing).
            Defaults to time().

    Returns:
        Result with tuple: (allowed, retry_after_seconds, remaining_tokens)

    Fail-open:
        On Redis errors, returns Success(True, 0.0, rule.max_tokens).
    """
    try:
        script_sha = await self._ensure_token_bucket_script()
        timestamp = time() if now_ts is None else now_ts
        raw = await self.redis.evalsha(
            script_sha,
            1,
            key_base,
            int(rule.max_tokens),
            float(rule.refill_rate),
            int(max(0, cost)),
            float(timestamp),
        )
        # Script reply layout: [allowed(0/1), retry_after(float), remaining(int)]
        return Success(value=(bool(raw[0]), float(raw[1]), int(raw[2])))
    except Exception:  # Fail-open: never block traffic on infra errors
        return Success(value=(True, 0.0, rule.max_tokens))
get_remaining async
get_remaining(
    *,
    key_base: str,
    rule: RateLimitRule,
    now_ts: float | None = None
) -> Result[int, RateLimitError]

Get remaining tokens without consuming any.

Implementation detail: calls Lua with cost=0 to avoid consumption.

Fail-open

On Redis errors, returns Success(rule.max_tokens).

Source code in src/infrastructure/rate_limit/redis_storage.py
async def get_remaining(
    self,
    *,
    key_base: str,
    rule: RateLimitRule,
    now_ts: float | None = None,
) -> Result[int, RateLimitError]:
    """Get remaining tokens without consuming any.

    Implementation detail: calls Lua with cost=0 to avoid consumption.

    Fail-open:
        On Redis errors, returns Success(rule.max_tokens).
    """
    result = await self.check_and_consume(
        key_base=key_base, rule=rule, cost=0, now_ts=now_ts
    )
    match result:
        case Success(value=(_, _, remaining)):
            return Success(value=remaining)
        case _:
            # Fail-open: return max tokens on any error
            return Success(value=rule.max_tokens)
reset async
reset(
    *,
    key_base: str,
    rule: RateLimitRule,
    now_ts: float | None = None
) -> Result[None, RateLimitError]

Reset the bucket to full capacity.

Unlike check operations, reset should report real errors to callers.

Source code in src/infrastructure/rate_limit/redis_storage.py
async def reset(
    self,
    *,
    key_base: str,
    rule: RateLimitRule,
    now_ts: float | None = None,
) -> Result[None, RateLimitError]:
    """Reset the bucket to full capacity.

    Unlike the check operations, reset surfaces real errors to callers.
    """
    try:
        timestamp = time() if now_ts is None else now_ts
        expiry = rule.ttl_seconds
        # Transactional pipeline: both keys are rewritten together.
        pipe = self.redis.pipeline(transaction=True)
        pipe.setex(f"{key_base}:tokens", expiry, int(rule.max_tokens))
        pipe.setex(f"{key_base}:time", expiry, float(timestamp))
        await pipe.execute()
        return Success(value=None)
    except Exception as exc:
        return Failure(
            error=RateLimitError(
                code=ErrorCode.RATE_LIMIT_RESET_FAILED,
                message=f"Failed to reset rate limit for '{key_base}': {exc}",
                details={"key_base": key_base},
            )
        )