# Source code for hsmkey.jwk_integration

"""JWCrypto integration for HSM-backed keys.

This module provides JWK (JSON Web Key) support for HSM-backed keys,
allowing seamless use of HSM keys with jwcrypto for JWS and JWE operations.

Example usage:

    from hsmkey import SessionPool
    from hsmkey.jwk_integration import HSMJWK
    from jwcrypto import JWS

    pool = SessionPool(
        module_path="/usr/lib/softhsm/libsofthsm2.so",
        token_label="my-token",
        user_pin="123456",
    )

    with pool.session() as session:
        # Create JWK from HSM key
        key = HSMJWK.from_hsm(session, key_label="rsa-2048")

        # Sign with JWS (signing happens on HSM)
        jws = JWS(b'{"sub": "user@example.com"}')
        jws.add_signature(key, alg='RS256', protected='{"typ":"JWT"}')
        token = jws.serialize(compact=True)
"""

from __future__ import annotations

import base64
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Iterator, Union

from cryptography.hazmat.primitives.asymmetric import ec, ed25519, ed448, rsa
from jwcrypto.jwk import JWK
from pkcs11 import Attribute, KeyType, ObjectClass

from .algorithms import CURVE_ALIASES, OID_TO_CURVE
from .exceptions import HSMKeyNotFoundError, HSMSessionError
from .keys import (
    PKCS11EllipticCurvePrivateKey,
    PKCS11EllipticCurvePublicKey,
    PKCS11Ed25519PrivateKey,
    PKCS11Ed25519PublicKey,
    PKCS11Ed448PrivateKey,
    PKCS11Ed448PublicKey,
    PKCS11RSAPrivateKey,
    PKCS11RSAPublicKey,
)
from .session import SessionPool

if TYPE_CHECKING:
    from pkcs11 import Session

# Every HSM-backed private key class that HSMJWK can wrap.
HSMPrivateKey = Union[
    PKCS11RSAPrivateKey, PKCS11EllipticCurvePrivateKey,
    PKCS11Ed25519PrivateKey, PKCS11Ed448PrivateKey,
]

# Public-key counterparts of HSMPrivateKey, listed in the same order.
HSMPublicKey = Union[
    PKCS11RSAPublicKey, PKCS11EllipticCurvePublicKey,
    PKCS11Ed25519PublicKey, PKCS11Ed448PublicKey,
]

# JWK "crv" names keyed by the curve name that cryptography reports
# (``EllipticCurve.name``, after CURVE_ALIASES normalization); curves
# missing from this table are passed through unchanged by the EC extractor.
CURVE_TO_JWK_CRV: dict[str, str] = dict(
    secp256r1="P-256",
    secp384r1="P-384",
    secp521r1="P-521",
    secp256k1="secp256k1",
    brainpoolP256r1="BP-256",
    brainpoolP384r1="BP-384",
    brainpoolP512r1="BP-512",
)

# DER-encoded OBJECT IDENTIFIERs (tag 0x06, length 3) for the EdDSA curves;
# compared byte-for-byte against a public key's CKA_EC_PARAMS.
ED25519_OID = b"\x06\x03\x2b\x65\x70"  # 1.3.101.112
ED448_OID = b"\x06\x03\x2b\x65\x71"  # 1.3.101.113


def _base64url_encode(data: bytes) -> str:
    """Encode bytes to base64url string without padding."""
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def _int_to_bytes(value: int, length: int | None = None) -> bytes:
    """Convert integer to big-endian bytes.

    Args:
        value: Integer to convert
        length: Optional fixed length (pads with zeros if needed)

    Returns:
        Big-endian byte representation
    """
    if length is None:
        length = (value.bit_length() + 7) // 8
    return value.to_bytes(length, "big")


class HSMJWK(JWK):
    """JWK backed by HSM keys.

    This class extends jwcrypto's JWK to support HSM-backed keys. All
    cryptographic operations (signing, decryption) are performed on the HSM,
    while public key parameters are extracted for JWK representation. The
    private key material never leaves the HSM.

    Attributes:
        _hsm_private_key: The HSM private key object (cached once loaded)
        _hsm_public_key: The HSM public key object (cached once loaded)
        _session: The PKCS#11 session
        _key_id: PKCS#11 key ID
        _key_label: PKCS#11 key label
    """

    def __init__(
        self,
        session: Session | None = None,
        key_id: bytes | None = None,
        key_label: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize HSM JWK.

        Args:
            session: PKCS#11 session (optional, can be set later)
            key_id: Key ID (CKA_ID)
            key_label: Key label (CKA_LABEL)
            **kwargs: Additional JWK parameters, forwarded to ``JWK.__init__``
        """
        self._session = session
        self._key_id = key_id
        self._key_label = key_label
        # Key objects are resolved lazily by _get_hsm_private_key /
        # _get_hsm_public_key unless from_hsm() pre-populates them.
        self._hsm_private_key: HSMPrivateKey | None = None
        self._hsm_public_key: HSMPublicKey | None = None
        super().__init__(**kwargs)

    @classmethod
    def from_hsm(
        cls,
        session: Session,
        key_id: bytes | None = None,
        key_label: str | None = None,
        kid: str | None = None,
        use: str | None = None,
        key_ops: list[str] | None = None,
    ) -> "HSMJWK":
        """Create a JWK from an HSM key.

        This factory method loads a key from the HSM and creates a JWK
        representation with the public key parameters.

        Args:
            session: PKCS#11 session
            key_id: Key ID (CKA_ID)
            key_label: Key label (CKA_LABEL)
            kid: JWK Key ID to assign
            use: Key use ('sig' or 'enc')
            key_ops: Allowed key operations

        Returns:
            HSMJWK instance backed by the HSM key

        Raises:
            HSMKeyNotFoundError: If the key is not found
            ValueError: If neither key_id nor key_label is given, or the
                key type is not supported
        """
        if key_id is None and key_label is None:
            raise ValueError("Either key_id or key_label must be provided")

        # Probe the token for the private key and determine its type
        hsm_key, key_type = cls._load_hsm_key(session, key_id, key_label)

        # Only public members (kty/n/e or kty/crv/x[/y]) are extracted, so
        # the resulting JWK never carries private fields.
        jwk_params = cls._extract_jwk_params(hsm_key, key_type)

        # Add optional parameters
        if kid is not None:
            jwk_params["kid"] = kid
        if use is not None:
            jwk_params["use"] = use
        if key_ops is not None:
            # Copy so later mutation of the caller's list cannot change the
            # operations this JWK allows.
            jwk_params["key_ops"] = list(key_ops)

        instance = cls(
            session=session,
            key_id=key_id,
            key_label=key_label,
            **jwk_params,
        )

        # Cache the loaded key so the first operation does not re-probe
        instance._hsm_private_key = hsm_key
        instance._hsm_public_key = hsm_key.public_key()

        return instance

    @classmethod
    def _load_hsm_key(
        cls,
        session: Session,
        key_id: bytes | None,
        key_label: str | None,
    ) -> tuple[HSMPrivateKey, str]:
        """Load HSM key and determine its type.

        Each supported key class is probed in turn (RSA, EC, Ed25519, Ed448)
        and the first one that loads the key wins. A failing probe is not
        fatal: any exception it raises just moves on to the next type.

        Args:
            session: PKCS#11 session
            key_id: Key ID
            key_label: Key label

        Returns:
            Tuple of (HSM private key, key type string)

        Raises:
            HSMKeyNotFoundError: If no key type matches. The last probe
                error, if any, is attached as ``__cause__`` for debugging.
        """
        last_error: Exception | None = None

        # RSA: accessing key_size forces the key to be loaded.
        try:
            rsa_key = PKCS11RSAPrivateKey(session, key_id, key_label)
            _ = rsa_key.key_size
        except Exception as exc:  # deliberate: probing, try next type
            last_error = exc
        else:
            return rsa_key, "RSA"

        # EC: accessing curve forces the key (and its curve) to be loaded.
        try:
            ec_key = PKCS11EllipticCurvePrivateKey(session, key_id, key_label)
            _ = ec_key.curve
        except Exception as exc:  # deliberate: probing, try next type
            last_error = exc
        else:
            return ec_key, "EC"

        # Ed25519: confirm the curve via the public key's CKA_EC_PARAMS.
        try:
            ed25519_key = PKCS11Ed25519PrivateKey(session, key_id, key_label)
            ec_params = bytes(ed25519_key.pkcs11_public_key[Attribute.EC_PARAMS])
        except Exception as exc:  # deliberate: probing, try next type
            last_error = exc
        else:
            if ec_params == ED25519_OID:
                return ed25519_key, "Ed25519"

        # Ed448: same CKA_EC_PARAMS check against the Ed448 OID.
        try:
            ed448_key = PKCS11Ed448PrivateKey(session, key_id, key_label)
            ec_params = bytes(ed448_key.pkcs11_public_key[Attribute.EC_PARAMS])
        except Exception as exc:  # deliberate: probing, try next type
            last_error = exc
        else:
            if ec_params == ED448_OID:
                return ed448_key, "Ed448"

        raise HSMKeyNotFoundError(
            f"Key not found: id={key_id}, label={key_label}"
        ) from last_error

    @classmethod
    def _extract_jwk_params(
        cls,
        hsm_key: HSMPrivateKey,
        key_type: str,
    ) -> dict[str, Any]:
        """Extract JWK parameters from HSM key.

        Args:
            hsm_key: HSM private key
            key_type: Key type string ("RSA", "EC", "Ed25519", "Ed448")

        Returns:
            Dictionary of JWK parameters

        Raises:
            ValueError: If key_type is not one of the supported strings
        """
        # The asserts only narrow the Union for type checkers.
        if key_type == "RSA":
            assert isinstance(hsm_key, PKCS11RSAPrivateKey)
            return cls._extract_rsa_params(hsm_key)
        elif key_type == "EC":
            assert isinstance(hsm_key, PKCS11EllipticCurvePrivateKey)
            return cls._extract_ec_params(hsm_key)
        elif key_type == "Ed25519":
            assert isinstance(hsm_key, PKCS11Ed25519PrivateKey)
            return cls._extract_ed25519_params(hsm_key)
        elif key_type == "Ed448":
            assert isinstance(hsm_key, PKCS11Ed448PrivateKey)
            return cls._extract_ed448_params(hsm_key)
        else:
            raise ValueError(f"Unsupported key type: {key_type}")

    @classmethod
    def _extract_rsa_params(cls, hsm_key: PKCS11RSAPrivateKey) -> dict[str, Any]:
        """Extract RSA JWK parameters.

        Args:
            hsm_key: RSA private key

        Returns:
            JWK parameters for RSA key (kty, n, e)
        """
        pub_key = hsm_key.public_key()
        numbers = pub_key.public_numbers()

        # Minimal-length big-endian encodings, then base64url
        n_bytes = _int_to_bytes(numbers.n)
        e_bytes = _int_to_bytes(numbers.e)

        return {
            "kty": "RSA",
            "n": _base64url_encode(n_bytes),
            "e": _base64url_encode(e_bytes),
        }

    @classmethod
    def _extract_ec_params(
        cls, hsm_key: PKCS11EllipticCurvePrivateKey
    ) -> dict[str, Any]:
        """Extract EC JWK parameters.

        Args:
            hsm_key: EC private key

        Returns:
            JWK parameters for EC key (kty, crv, x, y)
        """
        pub_key = hsm_key.public_key()
        numbers = pub_key.public_numbers()

        # Normalize the curve name, then map it to the JWK "crv" value;
        # unmapped curves fall through with their normalized name.
        curve_name = numbers.curve.name
        curve_name = CURVE_ALIASES.get(curve_name, curve_name)
        crv = CURVE_TO_JWK_CRV.get(curve_name, curve_name)

        # x and y are left-padded to the curve's byte size so both have a
        # fixed length regardless of leading zero bits.
        key_size = hsm_key.key_size
        coord_size = (key_size + 7) // 8

        x_bytes = _int_to_bytes(numbers.x, coord_size)
        y_bytes = _int_to_bytes(numbers.y, coord_size)

        return {
            "kty": "EC",
            "crv": crv,
            "x": _base64url_encode(x_bytes),
            "y": _base64url_encode(y_bytes),
        }

    @classmethod
    def _extract_ed25519_params(
        cls, hsm_key: PKCS11Ed25519PrivateKey
    ) -> dict[str, Any]:
        """Extract Ed25519 JWK parameters.

        Args:
            hsm_key: Ed25519 private key

        Returns:
            JWK parameters for Ed25519 key (OKP type)
        """
        pub_key = hsm_key.public_key()
        x_bytes = pub_key.public_bytes_raw()

        return {
            "kty": "OKP",
            "crv": "Ed25519",
            "x": _base64url_encode(x_bytes),
        }

    @classmethod
    def _extract_ed448_params(cls, hsm_key: PKCS11Ed448PrivateKey) -> dict[str, Any]:
        """Extract Ed448 JWK parameters.

        Args:
            hsm_key: Ed448 private key

        Returns:
            JWK parameters for Ed448 key (OKP type)
        """
        pub_key = hsm_key.public_key()
        x_bytes = pub_key.public_bytes_raw()

        return {
            "kty": "OKP",
            "crv": "Ed448",
            "x": _base64url_encode(x_bytes),
        }

    def _get_hsm_private_key(self) -> HSMPrivateKey:
        """Get or load the HSM private key.

        Returns:
            HSM private key (cached after the first successful load)

        Raises:
            HSMSessionError: If session is not available
            HSMKeyNotFoundError: If key not found
        """
        if self._hsm_private_key is not None:
            return self._hsm_private_key

        if self._session is None:
            raise HSMSessionError("No HSM session available")

        # Only the key object is needed here; the detected type is ignored.
        self._hsm_private_key, _ = self._load_hsm_key(
            self._session, self._key_id, self._key_label
        )
        return self._hsm_private_key

    def _get_hsm_public_key(self) -> HSMPublicKey:
        """Get or load the HSM public key.

        Returns:
            HSM public key (derived from the private key and cached)

        Raises:
            HSMSessionError: If the private key must be loaded and no
                session is available
            HSMKeyNotFoundError: If the private key must be loaded and is
                not found
        """
        if self._hsm_public_key is not None:
            return self._hsm_public_key

        private_key = self._get_hsm_private_key()
        self._hsm_public_key = private_key.public_key()
        return self._hsm_public_key

    def get_op_key(
        self,
        operation: str | None = None,
        arg: Any = None,
    ) -> Any:
        """Return the key object for the specified operation.

        This method is called by jwcrypto's JWS and JWE implementations to
        get the actual key for cryptographic operations. For HSM keys:

        - sign, decrypt, unwrapKey: returns the HSM private key
        - verify, encrypt, wrapKey: returns the HSM public key

        As in the base ``JWK.get_op_key``, the JWK's ``use`` and
        ``key_ops`` members are enforced for these operations: a key whose
        ``use`` is "enc" cannot sign, and an operation missing from a
        non-empty ``key_ops`` is refused.

        Args:
            operation: The operation to perform ('sign', 'verify', etc.)
            arg: Optional argument (algorithm, etc.)

        Returns:
            HSM key object (compatible with cryptography library interfaces)

        Raises:
            jwcrypto's constraint errors (raised by ``_check_constraints``)
                when ``use`` or ``key_ops`` forbid the operation.
            HSMSessionError, HSMKeyNotFoundError: If the key for a known
                operation cannot be loaded.
        """
        # Re-apply the usage constraints the parent implementation checks;
        # without this the `use`/`key_ops` given to from_hsm() would be
        # silently ignored.
        if operation in ("sign", "verify"):
            self._check_constraints("sig", operation)
        elif operation in ("encrypt", "decrypt", "wrapKey", "unwrapKey"):
            self._check_constraints("enc", operation)

        if operation in ("sign", "decrypt", "unwrapKey"):
            return self._get_hsm_private_key()
        if operation in ("verify", "encrypt", "wrapKey"):
            return self._get_hsm_public_key()

        # For unknown operations, try to return the private key, falling
        # back to the parent implementation if it cannot be loaded.
        try:
            return self._get_hsm_private_key()
        except (HSMSessionError, HSMKeyNotFoundError):
            return super().get_op_key(operation, arg)

    def has_private(self) -> bool:
        """Check if this JWK has a private key.

        For HSM keys, the private key exists on the HSM but cannot be
        exported.

        Returns:
            True (HSM keys always have private key)
        """
        # NOTE(review): jwcrypto appears to expose JWK.has_private as a
        # property. As a plain method, `key.has_private` (no call) is a
        # bound method and therefore always truthy. Confirm which form
        # callers use before converting this to a property.
        return True

    def export_private(self, as_dict: bool = False) -> dict | str:
        """Export private key.

        For HSM keys, this raises an error since private keys cannot leave
        the HSM.

        Raises:
            HSMSessionError: Always raised for HSM keys
        """
        raise HSMSessionError(
            "Cannot export private key from HSM. "
            "Private key material is protected and cannot leave the HSM."
        )

    def export_public(self, as_dict: bool = False) -> dict | str:
        """Export public key.

        Returns the public key parameters in JWK format.

        Args:
            as_dict: If True, return as dictionary; otherwise return JSON
                string

        Returns:
            Public key in JWK format
        """
        return super().export_public(as_dict=as_dict)


class HSMJWKSet:
    """JWK Set backed by HSM keys.

    Manages a collection of HSM-backed JWKs for use with jwcrypto, indexed
    by a Key ID chosen in :meth:`add_key`.
    """

    def __init__(self, session: Session) -> None:
        """Initialize HSM JWK Set.

        Args:
            session: PKCS#11 session used to load every key added to the set
        """
        self._session = session
        self._keys: dict[str, HSMJWK] = {}

    def add_key(
        self,
        key_id: bytes | None = None,
        key_label: str | None = None,
        kid: str | None = None,
        use: str | None = None,
        key_ops: list[str] | None = None,
    ) -> HSMJWK:
        """Add an HSM key to the set.

        The key is indexed under ``kid`` if given, otherwise ``key_label``,
        otherwise the hex form of ``key_id``. Only an explicit ``kid`` is
        written into the JWK itself.

        Args:
            key_id: PKCS#11 key ID
            key_label: PKCS#11 key label
            kid: JWK Key ID to assign
            use: Key use ('sig' or 'enc')
            key_ops: Allowed key operations

        Returns:
            The created HSMJWK

        Raises:
            ValueError: If the resolved Key ID is already in the set (checked
                before the HSM is contacted), if neither key_id nor
                key_label is given, or if no Key ID can be derived
            HSMKeyNotFoundError: If the key is not found on the HSM
        """
        # Resolve the index first so a duplicate is rejected without a
        # pointless round trip to the HSM.
        actual_kid = kid or key_label or (key_id.hex() if key_id else None)
        if actual_kid in self._keys:
            raise ValueError(f"Key ID already exists: {actual_kid}")

        jwk = HSMJWK.from_hsm(
            self._session,
            key_id=key_id,
            key_label=key_label,
            kid=kid,
            use=use,
            key_ops=key_ops,
        )

        # Reached e.g. with key_id=b"" and no kid/label: the key loaded but
        # there is nothing usable to index it under.
        if actual_kid is None:
            raise ValueError("Must provide kid, key_label, or key_id")

        self._keys[actual_kid] = jwk
        return jwk

    def get_key(self, kid: str) -> HSMJWK | None:
        """Get key by Key ID.

        Args:
            kid: JWK Key ID

        Returns:
            HSMJWK if found, None otherwise
        """
        return self._keys.get(kid)

    def __iter__(self) -> Iterator[HSMJWK]:
        """Iterate over keys in the set (in insertion order)."""
        return iter(self._keys.values())

    def __len__(self) -> int:
        """Return number of keys in the set."""
        return len(self._keys)


def jwk_from_hsm(
    session: Session,
    key_id: bytes | None = None,
    key_label: str | None = None,
    kid: str | None = None,
    use: str | None = None,
    key_ops: list[str] | None = None,
) -> HSMJWK:
    """Build an HSMJWK for a key stored on the HSM.

    Functional shorthand for ``HSMJWK.from_hsm()``; every argument is
    forwarded unchanged.

    Args:
        session: PKCS#11 session
        key_id: Key ID (CKA_ID)
        key_label: Key label (CKA_LABEL)
        kid: JWK Key ID to assign
        use: Key use ('sig' or 'enc')
        key_ops: Allowed key operations

    Returns:
        HSMJWK instance backed by the HSM key

    Example:
        with pool.session() as session:
            key = jwk_from_hsm(session, key_label="rsa-2048")
    """
    options = {
        "key_id": key_id,
        "key_label": key_label,
        "kid": kid,
        "use": use,
        "key_ops": key_ops,
    }
    return HSMJWK.from_hsm(session, **options)


@contextmanager
def hsm_session(
    module_path: str,
    token_label: str,
    pin: str,
) -> Iterator[Session]:
    """Open a PKCS#11 session for the duration of a ``with`` block.

    A new SessionPool is built from the given settings, and one session is
    borrowed from it and yielded to the caller.

    Args:
        module_path: Path to PKCS#11 library
        token_label: Token label
        pin: User PIN

    Yields:
        PKCS#11 session

    Example:
        with hsm_session("/usr/lib/softhsm/libsofthsm2.so", "my-token", "1234") as session:
            key = jwk_from_hsm(session, key_label="my-key")
    """
    session_pool = SessionPool(
        module_path=module_path,
        token_label=token_label,
        user_pin=pin,
    )
    # NOTE(review): the pool itself is never explicitly shut down here. If
    # SessionPool holds resources beyond the borrowed session, confirm how
    # they are released; that is not visible from this module.
    with session_pool.session() as active_session:
        yield active_session