Source code for hsmkey.session

"""PKCS#11 session management for hsmkey module."""

from __future__ import annotations

import threading
from contextlib import contextmanager
from typing import TYPE_CHECKING, Iterator

import pkcs11
from pkcs11 import KeyType, Mechanism, ObjectClass

from .exceptions import (
    HSMKeyNotFoundError,
    HSMPinError,
    HSMSessionError,
)

if TYPE_CHECKING:
    from pkcs11 import Session, Token


class SessionPool:
    """Provide PKCS#11 library handles and sessions for a single token.

    PKCS#11 library handles are cached per module path in a class-level
    dictionary that every ``SessionPool`` instance shares; access to that
    cache is serialized by a class-level lock.  Sessions are not reused by
    the methods of this class: :meth:`open_session` opens a new session on
    every call, and :meth:`session` closes it again when its ``with`` block
    exits.
    """

    # Library handles keyed by module path.  Deliberately a class attribute:
    # all instances share it, so each module path is loaded at most once by
    # this class.
    _lib_cache: dict[str, pkcs11.lib] = {}
    # Guards check-then-load on ``_lib_cache`` across threads.
    _lib_lock = threading.Lock()

    def __init__(
        self,
        module_path: str,
        token_label: str,
        user_pin: str | None = None,
        so_pin: str | None = None,
    ) -> None:
        """Initialize session pool.

        Args:
            module_path: Path to PKCS#11 library
            token_label: Label of token to use
            user_pin: User PIN for authentication
            so_pin: Security Officer PIN for admin operations; when set it
                takes precedence over ``user_pin`` in :meth:`open_session`
        """
        self.module_path = module_path
        self.token_label = token_label
        self.user_pin = user_pin
        self.so_pin = so_pin
        # NOTE(review): ``_session`` and ``_lock`` are not read by any method
        # in this class; kept in case other code depends on them — confirm.
        self._session: Session | None = None
        self._lock = threading.Lock()

    def _get_lib(self) -> pkcs11.lib:
        """Return the cached PKCS#11 library for ``module_path``, loading it once.

        Raises:
            HSMSessionError: If the library cannot be loaded
        """
        with self._lib_lock:
            if self.module_path not in self._lib_cache:
                try:
                    self._lib_cache[self.module_path] = pkcs11.lib(self.module_path)
                except Exception as e:
                    raise HSMSessionError(
                        f"Failed to load PKCS#11 library: {e}"
                    ) from e
            return self._lib_cache[self.module_path]

    def _get_token(self) -> Token:
        """Return the token whose label is ``token_label``.

        Every PKCS#11 failure during lookup is translated to
        ``HSMSessionError`` so callers (``open_session``) never see raw
        ``pkcs11`` exceptions from this step.

        Raises:
            HSMSessionError: If the library cannot be loaded, no token or
                more than one token matches the label, or the lookup fails
        """
        lib = self._get_lib()
        try:
            return lib.get_token(token_label=self.token_label)
        except pkcs11.NoSuchToken as e:
            raise HSMSessionError(
                f"Token not found: {self.token_label}"
            ) from e
        except pkcs11.MultipleTokensReturned as e:
            # Label is ambiguous; refuse to pick one arbitrarily.
            raise HSMSessionError(
                f"Multiple tokens found with label: {self.token_label}"
            ) from e
        except pkcs11.PKCS11Error as e:
            raise HSMSessionError(
                f"Failed to look up token {self.token_label}: {e}"
            ) from e

    def open_session(self, rw: bool = False) -> Session:
        """Open a new PKCS#11 session.

        Login happens through ``token.open()``: with ``so_pin`` if it is set,
        otherwise with ``user_pin`` if it is set, otherwise without a PIN.
        When both PINs are configured, the SO PIN wins.

        Args:
            rw: Whether to open read-write session

        Returns:
            PKCS#11 session; the caller is responsible for closing it

        Raises:
            HSMPinError: If PIN authentication fails
            HSMSessionError: If the library/token cannot be obtained or the
                session cannot be opened
        """
        # Raises HSMSessionError itself on any lookup failure.
        token = self._get_token()
        try:
            if self.so_pin:
                session = token.open(rw=rw, so_pin=self.so_pin)
            elif self.user_pin:
                session = token.open(rw=rw, user_pin=self.user_pin)
            else:
                session = token.open(rw=rw)
            return session
        except pkcs11.PinIncorrect as e:
            raise HSMPinError("PIN incorrect") from e
        except pkcs11.PKCS11Error as e:
            raise HSMSessionError(f"Failed to open session: {e}") from e

    @contextmanager
    def session(self, rw: bool = False) -> Iterator[Session]:
        """Context manager for PKCS#11 session.

        Opens a session with :meth:`open_session` and always closes it on
        exit, including when the ``with`` body raises.

        Args:
            rw: Whether to open read-write session

        Yields:
            PKCS#11 session

        Example:
            with pool.session() as session:
                key = session.get_key(...)
        """
        session = self.open_session(rw=rw)
        try:
            yield session
        finally:
            # python-pkcs11 handles logout automatically when closing
            # the session that was opened with user_pin/so_pin
            session.close()

    def _get_key(
        self,
        session: Session,
        key_type: KeyType,
        object_class: ObjectClass,
        key_id: bytes | None,
        key_label: str | None,
        description: str,
    ):
        """Look up one key object, translating ``NoSuchKey`` to ``HSMKeyNotFoundError``.

        Args:
            session: PKCS#11 session
            key_type: Type of key
            object_class: PKCS#11 object class to search for
            key_id: Key ID (CKA_ID), or None to not filter on it
            key_label: Key label (CKA_LABEL), or None to not filter on it
            description: Prefix for the error message (e.g. ``"Private"``)

        Raises:
            HSMKeyNotFoundError: If no matching key exists
        """
        try:
            return session.get_key(
                key_type=key_type,
                object_class=object_class,
                id=key_id,
                label=key_label,
            )
        except pkcs11.NoSuchKey as e:
            raise HSMKeyNotFoundError(
                f"{description} key not found: id={key_id}, label={key_label}"
            ) from e

    def get_private_key(
        self,
        session: Session,
        key_type: KeyType,
        key_id: bytes | None = None,
        key_label: str | None = None,
    ) -> pkcs11.PrivateKey:
        """Get private key from HSM.

        Args:
            session: PKCS#11 session
            key_type: Type of key (RSA, EC, EC_EDWARDS)
            key_id: Key ID (CKA_ID)
            key_label: Key label (CKA_LABEL)

        Returns:
            PKCS#11 private key object

        Raises:
            HSMKeyNotFoundError: If key not found
        """
        return self._get_key(
            session, key_type, ObjectClass.PRIVATE_KEY, key_id, key_label, "Private"
        )

    def get_public_key(
        self,
        session: Session,
        key_type: KeyType,
        key_id: bytes | None = None,
        key_label: str | None = None,
    ) -> pkcs11.PublicKey:
        """Get public key from HSM.

        Args:
            session: PKCS#11 session
            key_type: Type of key (RSA, EC, EC_EDWARDS)
            key_id: Key ID (CKA_ID)
            key_label: Key label (CKA_LABEL)

        Returns:
            PKCS#11 public key object

        Raises:
            HSMKeyNotFoundError: If key not found
        """
        return self._get_key(
            session, key_type, ObjectClass.PUBLIC_KEY, key_id, key_label, "Public"
        )