Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions mplang/kernels/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,9 @@ def _get_rng() -> np.random.Generator:
def _keystream(key: bytes, nonce: bytes, length: int) -> bytes:
    """Return ``length`` bytes of keystream derived from ``key`` and ``nonce``.

    The stream is the digest ``blake2b(key + nonce)`` repeated as often as
    needed and truncated to ``length`` bytes.

    Args:
        key: Raw key bytes.
        nonce: Raw nonce bytes.
        length: Number of keystream bytes requested.

    Returns:
        bytes: Exactly ``length`` bytes, or ``b""`` when ``length <= 0``.
    """
    # WARNING (INSECURE): hash-based keystream over key||nonce only.
    # NOTE(review): no block counter is mixed into the hash, so the stream
    # repeats with the period of one digest; confirm this is intended before
    # relying on it for anything beyond mock/testing use.
    if length <= 0:
        return b""
    # The digest does not depend on the loop, so compute it once.
    # Presumably `blake2b` (defined outside this block) returns bytes - verify.
    block = blake2b(key + nonce)
    out = bytearray()
    while len(out) < length:
        out.extend(block)
    return bytes(out[:length])


Expand All @@ -68,7 +66,7 @@ def _crypto_encrypt(
pt_bytes_np = pt_bytes.to_numpy().astype(np.uint8, copy=False)
key_np = key.to_numpy().astype(np.uint8, copy=False)
rng = _get_rng()
nonce = rng.integers(0, 256, size=(12,), dtype=np.uint8)
nonce = rng.integers(0, 256, size=(16,), dtype=np.uint8)
stream = np.frombuffer(
_keystream(key_np.tobytes(), nonce.tobytes(), pt_bytes_np.size), dtype=np.uint8
)
Expand All @@ -83,8 +81,8 @@ def _crypto_decrypt(
) -> TensorValue:
ct_np = ct_with_nonce.to_numpy().astype(np.uint8, copy=False)
key_np = key.to_numpy().astype(np.uint8, copy=False)
nonce = ct_np[:12]
ct = ct_np[12:]
nonce = ct_np[:16]
ct = ct_np[16:]
stream = np.frombuffer(
_keystream(key_np.tobytes(), nonce.tobytes(), len(ct)), dtype=np.uint8
)
Expand Down
120 changes: 100 additions & 20 deletions mplang/ops/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,37 @@

from __future__ import annotations

from jax.tree_util import PyTreeDef, tree_flatten

from mplang.core import UINT8, TensorType
from mplang.core.mpobject import MPObject
from mplang.core.pfunc import PFunction
from mplang.ops.base import stateless_mod

_CRYPTO_MOD = stateless_mod("crypto")


def _get_algo_overhead(algo: str) -> int:
"""Get ciphertext overhead for a given encryption algorithm.

Args:
algo: Encryption algorithm identifier

Returns:
int: Number of overhead bytes added to plaintext length
"""
overhead_map = {
"aes-ctr": 16, # nonce only (legacy compatibility)
"aes-gcm": 28, # nonce(12) + tag(16) for AES-GCM
"sm4-gcm": 28, # nonce(12) + tag(16) for SM4-GCM
}

if algo not in overhead_map:
# return unknown overhead as -1
return -1
return overhead_map[algo]


@_CRYPTO_MOD.simple_op()
def keygen(*, length: int = 32) -> TensorType:
"""Generate random bytes for symmetric keys or generic randomness.
Expand All @@ -47,40 +72,95 @@ def keygen(*, length: int = 32) -> TensorType:
return TensorType(UINT8, (length,))


@_CRYPTO_MOD.op_def()
def enc(
    plaintext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Symmetric encryption with algorithm-aware output sizing.

    API: enc(plaintext: u8[N], key: u8[M], algo: str = "aes-ctr") -> ciphertext: u8[N + overhead]

    Supported algorithms and overhead:
    - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
    - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
    - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    The algo parameter is stored in the PFunction attributes for backend use.

    Args:
        plaintext: 1-D UINT8 object to encrypt.
        key: Key object; only its type info is recorded here.
        algo: Algorithm identifier used to size the output.

    Returns:
        A ``(pfunc, [plaintext, key], treedef)`` triple describing the op.

    Raises:
        TypeError: If ``plaintext`` is not UINT8 or not 1-D.
    """
    pt_ty = plaintext
    if pt_ty.dtype != UINT8:
        raise TypeError("enc expects UINT8 plaintext")
    if len(pt_ty.shape) != 1:
        raise TypeError("enc expects 1-D plaintext")

    # Look up the per-algorithm overhead; -1 means the algorithm is unknown.
    overhead = _get_algo_overhead(algo)
    length = pt_ty.shape[0]
    if length >= 0 and overhead >= 0:
        outs_info = (TensorType(UINT8, (length + overhead,)),)
    else:
        # Unknown length or overhead, return dynamic length
        outs_info = (TensorType(UINT8, (-1,)),)

    ins_info = (TensorType.from_obj(pt_ty), TensorType.from_obj(key))
    pfunc = PFunction(
        fn_type="crypto.enc",
        ins_info=ins_info,
        outs_info=outs_info,
        algo=algo,
    )
    _, treedef = tree_flatten(outs_info[0])
    return pfunc, [plaintext, key], treedef


@_CRYPTO_MOD.op_def()
def dec(
    ciphertext: MPObject, key: MPObject, algo: str = "aes-ctr"
) -> tuple[PFunction, list[MPObject], PyTreeDef]:
    """Symmetric decryption with algorithm-aware input sizing.

    API: dec(ciphertext: u8[N + overhead], key: u8[M], algo: str = "aes-ctr") -> plaintext: u8[N]

    Supported algorithms and overhead:
    - "aes-ctr": 16 bytes (nonce only, legacy compatibility)
    - "aes-gcm": 28 bytes (nonce + 16-byte authentication tag)
    - "sm4-gcm": 28 bytes (nonce + 16-byte authentication tag)

    The algo parameter is stored in the PFunction attributes for backend use.
    Backend is responsible for parsing the ciphertext format according to algo.

    Args:
        ciphertext: 1-D UINT8 object holding the ciphertext plus overhead.
        key: Key object; only its type info is recorded here.
        algo: Algorithm identifier used to size the output.

    Returns:
        A ``(pfunc, [ciphertext, key], treedef)`` triple describing the op.

    Raises:
        TypeError: If ``ciphertext`` is not UINT8, not 1-D, or is statically
            shorter than the overhead of a known ``algo``.
    """
    ct_ty = ciphertext
    if ct_ty.dtype != UINT8:
        raise TypeError("dec expects UINT8 ciphertext")
    if len(ct_ty.shape) != 1:
        raise TypeError("dec expects 1-D ciphertext")

    # Look up the per-algorithm overhead; -1 means the algorithm is unknown.
    overhead = _get_algo_overhead(algo)
    length = ct_ty.shape[0]
    sizes_known = length >= 0 and overhead >= 0

    # Validate minimum ciphertext length
    if sizes_known and length < overhead:
        raise TypeError(
            f"dec expects ciphertext with at least {overhead} bytes for algo='{algo}', but got {length} bytes"
        )

    # Compute output plaintext length
    if sizes_known:
        outs_info = (TensorType(UINT8, (length - overhead,)),)
    else:
        # Unknown length or overhead, return dynamic length
        outs_info = (TensorType(UINT8, (-1,)),)

    ins_info = (TensorType.from_obj(ct_ty), TensorType.from_obj(key))
    pfunc = PFunction(
        fn_type="crypto.dec",
        ins_info=ins_info,
        outs_info=outs_info,
        algo=algo,
    )
    _, treedef = tree_flatten(outs_info[0])
    return pfunc, [ciphertext, key], treedef


@_CRYPTO_MOD.simple_op()
Expand Down
Loading
Loading