"""Elliptic Curve key implementations backed by HSM."""
from __future__ import annotations
from typing import TYPE_CHECKING
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.asymmetric.utils import (
Prehashed,
decode_dss_signature,
encode_dss_signature,
)
from pkcs11 import Attribute, KeyType, Mechanism
from ..algorithms import OID_TO_CURVE, normalize_curve_name
from ..exceptions import HSMOperationError, HSMUnsupportedError
from .base import PKCS11PrivateKeyMixin
if TYPE_CHECKING:
from pkcs11 import Session
# Mapping of curve names to cryptography curve classes
CURVE_CLASSES: dict[str, type[ec.EllipticCurve]] = {
"secp256r1": ec.SECP256R1,
"secp384r1": ec.SECP384R1,
"secp521r1": ec.SECP521R1,
"secp256k1": ec.SECP256K1,
"brainpoolP256r1": ec.BrainpoolP256R1,
"brainpoolP384r1": ec.BrainpoolP384R1,
"brainpoolP512r1": ec.BrainpoolP512R1,
}
# Key sizes for each curve (in bytes for signature components)
CURVE_KEY_SIZES: dict[str, int] = {
"secp256r1": 32,
"secp384r1": 48,
"secp521r1": 66,
"secp256k1": 32,
"brainpoolP256r1": 32,
"brainpoolP320r1": 40,
"brainpoolP384r1": 48,
"brainpoolP512r1": 64,
}
def _int_to_bytes(value: int, length: int) -> bytes:
"""Convert integer to fixed-length big-endian bytes."""
return value.to_bytes(length, "big")
def _bytes_to_int(data: bytes) -> int:
"""Convert big-endian bytes to integer."""
return int.from_bytes(data, "big")
[docs]
class PKCS11EllipticCurvePrivateKey(PKCS11PrivateKeyMixin, ec.EllipticCurvePrivateKey):
"""Elliptic Curve private key backed by HSM.
This class implements the cryptography library's EllipticCurvePrivateKey
interface while performing all cryptographic operations on the HSM.
"""
_key_type = KeyType.EC
def __init__(
self,
session: Session,
key_id: bytes | None = None,
key_label: str | None = None,
) -> None:
"""Initialize HSM EC private key.
Args:
session: PKCS#11 session
key_id: Key ID (CKA_ID)
key_label: Key label (CKA_LABEL)
"""
PKCS11PrivateKeyMixin.__init__(self, session, key_id, key_label)
self._curve: ec.EllipticCurve | None = None
@property
def curve(self) -> ec.EllipticCurve:
"""Return the curve used by this key."""
if self._curve is None:
ec_params = bytes(self.pkcs11_public_key[Attribute.EC_PARAMS])
curve_name = OID_TO_CURVE.get(ec_params)
if curve_name is None:
# Try to find by substring match for some HSM implementations
for oid, name in OID_TO_CURVE.items():
if oid in ec_params or ec_params in oid:
curve_name = name
break
if curve_name is None:
raise ValueError(f"Unknown curve OID: {ec_params.hex()}")
curve_class = CURVE_CLASSES.get(curve_name)
if curve_class is None:
raise ValueError(f"Unsupported curve: {curve_name}")
self._curve = curve_class()
return self._curve
@property
def key_size(self) -> int:
"""Return key size in bits."""
return self.curve.key_size
[docs]
def public_key(self) -> "PKCS11EllipticCurvePublicKey":
"""Return the public key corresponding to this private key."""
return PKCS11EllipticCurvePublicKey.from_pkcs11_key(
self._session,
self.pkcs11_public_key,
self._key_id,
self._key_label,
)
[docs]
def sign(
self,
data: bytes,
signature_algorithm: ec.EllipticCurveSignatureAlgorithm,
) -> bytes:
"""Sign data using ECDSA.
Args:
data: Data to sign
signature_algorithm: Signature algorithm (ECDSA with hash)
Returns:
DER-encoded signature
Raises:
HSMOperationError: If signing fails
"""
if not isinstance(signature_algorithm, ec.ECDSA):
raise ValueError(f"Unsupported signature algorithm: {type(signature_algorithm)}")
algorithm = signature_algorithm.algorithm
# Pre-hash the data if not already prehashed
if isinstance(algorithm, Prehashed):
hash_data = data
else:
from cryptography.hazmat.primitives.hashes import Hash
from cryptography.hazmat.backends import default_backend
h = Hash(algorithm, backend=default_backend())
h.update(data)
hash_data = h.finalize()
try:
# PKCS#11 ECDSA returns raw signature (r || s)
raw_signature = self.pkcs11_private_key.sign(
hash_data,
mechanism=Mechanism.ECDSA,
)
raw_signature = bytes(raw_signature)
# Convert raw signature to DER format
# Raw signature is r || s, each component is curve key size bytes
curve_name = normalize_curve_name(self.curve.name)
component_size = CURVE_KEY_SIZES.get(curve_name, len(raw_signature) // 2)
r = _bytes_to_int(raw_signature[:component_size])
s = _bytes_to_int(raw_signature[component_size:])
return encode_dss_signature(r, s)
except Exception as e:
raise HSMOperationError(f"ECDSA signing failed: {e}") from e
[docs]
def exchange(
self,
algorithm: ec.ECDH,
peer_public_key: ec.EllipticCurvePublicKey,
) -> bytes:
"""Perform ECDH key exchange.
Args:
algorithm: ECDH algorithm
peer_public_key: Peer's public key
Returns:
Shared secret
Raises:
HSMOperationError: If key exchange fails
"""
try:
# Get peer public key in uncompressed point format
peer_bytes = peer_public_key.public_bytes(
serialization.Encoding.X962,
serialization.PublicFormat.UncompressedPoint,
)
# Perform ECDH on HSM
shared_secret = self.pkcs11_private_key.derive_key(
KeyType.GENERIC_SECRET,
self.key_size // 8,
mechanism=Mechanism.ECDH1_DERIVE,
mechanism_param=peer_bytes,
)
# Extract the key value
return bytes(shared_secret[Attribute.VALUE])
except Exception as e:
raise HSMOperationError(f"ECDH key exchange failed: {e}") from e
[docs]
def private_numbers(self) -> ec.EllipticCurvePrivateNumbers:
"""Not supported for HSM keys."""
self._raise_unsupported("private_numbers()")
[docs]
def private_bytes(
self,
encoding: serialization.Encoding,
format: serialization.PrivateFormat,
encryption_algorithm: serialization.KeySerializationEncryption,
) -> bytes:
"""Not supported for HSM keys."""
self._raise_unsupported("private_bytes()")
[docs]
def private_bytes_raw(self) -> bytes:
"""Not supported for HSM keys."""
self._raise_unsupported("private_bytes_raw()")
[docs]
class PKCS11EllipticCurvePublicKey(ec.EllipticCurvePublicKey):
"""Elliptic Curve public key from HSM.
This class wraps the public key data extracted from HSM and provides
the standard cryptography library interface.
"""
def __init__(
self,
curve: ec.EllipticCurve,
x: int,
y: int,
) -> None:
"""Initialize EC public key.
Args:
curve: Elliptic curve
x: X coordinate
y: Y coordinate
"""
self._curve = curve
self._x = x
self._y = y
# Create internal cryptography public key for operations
self._crypto_key = ec.EllipticCurvePublicNumbers(
x=x,
y=y,
curve=curve,
).public_key()
[docs]
@classmethod
def from_pkcs11_key(
cls,
session: "Session",
pkcs11_key,
key_id: bytes | None = None,
key_label: str | None = None,
) -> "PKCS11EllipticCurvePublicKey":
"""Create public key from PKCS#11 public key object.
Args:
session: PKCS#11 session
pkcs11_key: PKCS#11 public key object
key_id: Key ID
key_label: Key label
Returns:
PKCS11EllipticCurvePublicKey instance
"""
# Get curve from EC_PARAMS
ec_params = bytes(pkcs11_key[Attribute.EC_PARAMS])
curve_name = OID_TO_CURVE.get(ec_params)
if curve_name is None:
# Try substring match
for oid, name in OID_TO_CURVE.items():
if oid in ec_params or ec_params in oid:
curve_name = name
break
if curve_name is None:
raise ValueError(f"Unknown curve OID: {ec_params.hex()}")
curve_class = CURVE_CLASSES.get(curve_name)
if curve_class is None:
raise ValueError(f"Unsupported curve: {curve_name}")
curve = curve_class()
# Get EC_POINT (uncompressed point format: 04 || x || y)
ec_point = bytes(pkcs11_key[Attribute.EC_POINT])
# EC_POINT in PKCS#11 is typically wrapped in a DER OCTET STRING
# Format: 04 <length> 04 <x> <y>
# where 04 is OCTET STRING tag, <length> is the length byte(s),
# and the content is the uncompressed point (04 || x || y)
point_data = ec_point
# Check if wrapped in OCTET STRING (tag 0x04)
if point_data[0] == 0x04 and len(point_data) > 2:
# Parse DER length
length_byte = point_data[1]
if length_byte < 0x80:
# Short form length
content_start = 2
elif length_byte == 0x81:
# Long form, 1 length byte
content_start = 3
elif length_byte == 0x82:
# Long form, 2 length bytes
content_start = 4
else:
content_start = 0 # Assume no wrapping
if content_start > 0:
inner_data = point_data[content_start:]
# Check if inner data starts with uncompressed point marker
if inner_data[0] == 0x04:
point_data = inner_data
# Now point_data should be: 04 || x || y
if point_data[0] != 0x04:
raise ValueError(f"Expected uncompressed point format, got: {point_data[:5].hex()}")
# Remove the 0x04 marker
coord_bytes = point_data[1:]
# Split into x and y coordinates
coord_size = len(coord_bytes) // 2
x = _bytes_to_int(coord_bytes[:coord_size])
y = _bytes_to_int(coord_bytes[coord_size:])
return cls(curve, x, y)
@property
def curve(self) -> ec.EllipticCurve:
"""Return the curve."""
return self._curve
@property
def key_size(self) -> int:
"""Return key size in bits."""
return self._curve.key_size
[docs]
def public_numbers(self) -> ec.EllipticCurvePublicNumbers:
"""Return EC public numbers."""
return ec.EllipticCurvePublicNumbers(
x=self._x,
y=self._y,
curve=self._curve,
)
[docs]
def public_bytes(
self,
encoding: serialization.Encoding,
format: serialization.PublicFormat,
) -> bytes:
"""Serialize public key."""
return self._crypto_key.public_bytes(encoding, format)
[docs]
def public_bytes_raw(self) -> bytes:
"""Return raw public key bytes (uncompressed point without 0x04 prefix)."""
return self._crypto_key.public_bytes(
serialization.Encoding.X962,
serialization.PublicFormat.UncompressedPoint,
)[1:] # Remove 0x04 prefix
[docs]
def verify(
self,
signature: bytes,
data: bytes,
signature_algorithm: ec.EllipticCurveSignatureAlgorithm,
) -> None:
"""Verify a signature.
Args:
signature: DER-encoded signature
data: Original data
signature_algorithm: Signature algorithm
Raises:
InvalidSignature: If verification fails
"""
self._crypto_key.verify(signature, data, signature_algorithm)
def __eq__(self, other: object) -> bool:
"""Check equality."""
if not isinstance(other, (PKCS11EllipticCurvePublicKey, ec.EllipticCurvePublicKey)):
return False
other_numbers = other.public_numbers()
return (
self._x == other_numbers.x
and self._y == other_numbers.y
and self._curve.name == other_numbers.curve.name
)
def __hash__(self) -> int:
"""Hash based on key parameters."""
return hash((self._x, self._y, self._curve.name))
def __copy__(self) -> "PKCS11EllipticCurvePublicKey":
"""Create a copy."""
return PKCS11EllipticCurvePublicKey(self._curve, self._x, self._y)
def __deepcopy__(self, memo: dict) -> "PKCS11EllipticCurvePublicKey":
"""Create a deep copy."""
return self.__copy__()