# Source code for rapsqlite._connection_state

"""Connection state cache for sync total_changes / in_transaction (aiosqlite compat).

This module provides the cached state and wrappers that update it for
begin/commit/rollback/close/__aenter__/transaction. Used by rapsqlite.__init__
after _compat patches are applied.
"""

from __future__ import annotations

import threading
import weakref
from typing import Any, TYPE_CHECKING, cast

if TYPE_CHECKING:
    from typing import Protocol

    class Connection(Protocol):
        """Structural type used only for annotations in this module."""

        def __aenter__(self): ...
        def __aexit__(self, *args): ...
        async def close(self): ...

        # total_changes, in_transaction, begin, commit, rollback, transaction
        # are patched by this module
else:
    Connection = Any

# Per-connection cache keyed by id(conn). Each value is a dict with the keys
# "total_changes" and "in_transaction" (see _get_conn_state).
_connection_state: dict[int, dict[str, Any]] = {}
_connection_state_lock = threading.Lock()

# Set by apply_state(Connection) before attaching wrappers
_orig_total_changes: Any = None
_orig_in_transaction: Any = None
_orig_begin: Any = None
_orig_commit: Any = None
_orig_rollback: Any = None
_orig_close: Any = None
_orig_aenter: Any = None
_orig_transaction: Any = None

# Connection attributes that apply_state() replaces with wrappers.
_PATCHED_ATTRS: tuple[str, ...] = (
    "total_changes",
    "in_transaction",
    "begin",
    "commit",
    "rollback",
    "close",
    "__aenter__",
    "transaction",
)
# Class attribute under which apply_state() stashes the *unpatched* values of
# _PATCHED_ATTRS. Re-applying then re-wraps the real originals instead of
# wrapping its own wrappers, which would make e.g. begin() recurse forever.
_ORIGINALS_ATTR = "_rapsqlite_state_originals"


def _register_gc_cleanup(conn: Connection, cid: int) -> None:
    """Drop the cached state for *conn* when the object is garbage-collected.

    State is otherwise removed only by the patched ``close``. Without this,
    entries for connections that never go through it would stay forever. A
    later object that reuses the same ``id()`` would also inherit them.
    """
    try:
        # The callback deliberately does not take _connection_state_lock: it
        # may run from a GC pass triggered while that non-reentrant lock is
        # held by the same thread, which would deadlock.
        finalizer = weakref.finalize(conn, _connection_state.pop, cid, None)
    except TypeError:
        # The connection type does not support weak references; fall back to
        # explicit cleanup in close() only.
        return
    # Nothing worth cleaning up at interpreter shutdown.
    finalizer.atexit = False


def _get_conn_state(conn: Connection) -> dict[str, Any]:
    """Return the cached state dict for *conn*, creating it on first use.

    Lookup and insertion happen under ``_connection_state_lock``. The returned
    dict is the shared cache entry; callers update it in place.
    """
    cid = id(conn)
    with _connection_state_lock:
        state = _connection_state.get(cid)
        created = state is None
        if created:
            state = {"total_changes": 0, "in_transaction": False}
            _connection_state[cid] = state
    if created:
        _register_gc_cleanup(conn, cid)
    return state


def _cleanup_conn_state(conn: Connection) -> None:
    """Remove state for a connection (thread-safe, call on close)."""
    with _connection_state_lock:
        _connection_state.pop(id(conn), None)


def _total_changes_prop(self: Connection) -> int:
    """Get total database changes since connection was opened (sync property for aiosqlite compat)."""
    return cast(int, _get_conn_state(self).get("total_changes", 0))


def _in_transaction_prop(self: Connection) -> bool:
    """Check if connection is in a transaction (sync property for aiosqlite compat)."""
    return cast(bool, _get_conn_state(self).get("in_transaction", False))


async def _refresh_total_changes(conn: Connection, state: dict[str, Any]) -> None:
    """Best-effort refresh of ``state["total_changes"]`` from the database.

    On any failure of the underlying async ``total_changes`` call, the
    previously cached value is kept.
    """
    try:
        state["total_changes"] = await _orig_total_changes(conn)
    except Exception:
        # Deliberately best-effort: a stale cache beats failing the caller.
        pass


async def _update_connection_state(conn: Connection) -> None:
    """Update cached total_changes and in_transaction from the database.

    Both values are refreshed independently and best-effort; a failure
    leaves the previously cached value in place.
    """
    state = _get_conn_state(conn)
    await _refresh_total_changes(conn, state)
    try:
        state["in_transaction"] = await _orig_in_transaction(conn)
    except Exception:
        # Best-effort, as above.
        pass


async def _begin_with_state_update(self: Connection) -> None:
    """Call the original ``begin`` and mark the connection as in a transaction."""
    await _orig_begin(self)
    _get_conn_state(self)["in_transaction"] = True


async def _commit_with_state_update(self: Connection) -> None:
    """Call the original ``commit``, clear ``in_transaction``, refresh ``total_changes``."""
    await _orig_commit(self)
    state = _get_conn_state(self)
    state["in_transaction"] = False
    await _refresh_total_changes(self, state)


async def _rollback_with_state_update(self: Connection) -> None:
    """Call the original ``rollback``, clear ``in_transaction``, refresh ``total_changes``.

    The refresh keeps the cache in step with the database in the same way
    ``commit`` does; otherwise changes made before the rollback would only
    become visible after the next commit.
    """
    await _orig_rollback(self)
    state = _get_conn_state(self)
    state["in_transaction"] = False
    await _refresh_total_changes(self, state)


async def _close_with_state_cleanup(self: Connection) -> None:
    """Drop the cached state, then call the original ``close``."""
    _cleanup_conn_state(self)
    await _orig_close(self)


async def _aenter_with_state_init(self: Connection) -> Connection:
    """Call the original ``__aenter__`` and make sure a cache entry exists."""
    result = await _orig_aenter(self)
    # Only ensures the entry (and its GC cleanup) exists; value unused here.
    _get_conn_state(self)
    return cast(Connection, result)


class _TransactionContextManagerWithState:
    """Wrapper around TransactionContextManager that updates in_transaction state."""

    def __init__(self, conn: Connection, orig_cm: Any) -> None:
        self._conn = conn
        self._orig_cm = orig_cm

    async def __aenter__(self) -> Connection:
        result = await self._orig_cm.__aenter__()
        _get_conn_state(self._conn)["in_transaction"] = True
        return cast(Connection, result)

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
        """Exit the wrapped manager, then update the cached state.

        Returns the wrapped ``__aexit__`` result unchanged. A truthy value
        suppresses the in-flight exception, so it must be passed through.
        """
        result = await self._orig_cm.__aexit__(exc_type, exc_val, exc_tb)
        state = _get_conn_state(self._conn)
        state["in_transaction"] = False
        # Changes made inside the block are otherwise invisible in the sync
        # total_changes property until the next explicit commit().
        await _refresh_total_changes(self._conn, state)
        return result


def _transaction_with_state(self: Connection) -> _TransactionContextManagerWithState:
    """Return a transaction context manager that updates in_transaction state."""
    return _TransactionContextManagerWithState(self, _orig_transaction(self))


def apply_state(Connection: type) -> None:
    """Attach connection state cache and wrappers to Connection.

    Call after _compat.apply_compat() so commit/rollback are already no-op
    wrappers.

    Safe to call more than once. The unpatched attributes are captured on the
    first call and stored on the class, and later calls reuse them rather
    than wrapping the wrappers installed by an earlier call.
    """
    global _orig_total_changes, _orig_in_transaction
    global _orig_begin, _orig_commit, _orig_rollback
    global _orig_close, _orig_aenter, _orig_transaction

    originals = getattr(Connection, _ORIGINALS_ATTR, None)
    if originals is None:
        # First application: capture the attributes before patching them.
        originals = {name: getattr(Connection, name) for name in _PATCHED_ATTRS}
        setattr(Connection, _ORIGINALS_ATTR, originals)

    _orig_total_changes = originals["total_changes"]
    _orig_in_transaction = originals["in_transaction"]
    _orig_begin = originals["begin"]
    _orig_commit = originals["commit"]
    _orig_rollback = originals["rollback"]
    _orig_close = originals["close"]
    _orig_aenter = originals["__aenter__"]
    _orig_transaction = originals["transaction"]

    Connection.total_changes = property(_total_changes_prop)  # type: ignore[attr-defined]
    Connection.in_transaction = property(_in_transaction_prop)  # type: ignore[attr-defined]
    # The original async accessors stay reachable under *_async names.
    Connection.total_changes_async = _orig_total_changes  # type: ignore[attr-defined]
    Connection.in_transaction_async = _orig_in_transaction  # type: ignore[attr-defined]
    Connection.begin = _begin_with_state_update  # type: ignore[attr-defined]
    Connection.commit = _commit_with_state_update  # type: ignore[attr-defined]
    Connection.rollback = _rollback_with_state_update  # type: ignore[attr-defined]
    Connection.close = _close_with_state_cleanup  # type: ignore[attr-defined]
    Connection.__aenter__ = _aenter_with_state_init  # type: ignore[attr-defined]
    Connection.transaction = _transaction_with_state  # type: ignore[attr-defined]