Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sm2: Fix heap allocation #1099

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions sm2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,19 @@ edition = "2024"
rust-version = "1.85"

[dependencies]
elliptic-curve = { version = "0.14.0-rc.0", default-features = false, features = ["sec1"] }
elliptic-curve = { version = "0.14.0-rc.0", default-features = false, features = [
"sec1",
] }
rand_core = { version = "0.9", default-features = false }

# optional dependencies
primeorder = { version = "=0.14.0-pre.2", optional = true, path = "../primeorder" }
rfc6979 = { version = "=0.5.0-pre.4", optional = true }
serdect = { version = "0.3", optional = true, default-features = false }
signature = { version = "=2.3.0-pre.6", optional = true, features = ["rand_core"] }
signature = { version = "=2.3.0-pre.6", optional = true, features = [
"rand_core",
"digest",
] }
sm3 = { version = "=0.5.0-pre.5", optional = true, default-features = false }

[dev-dependencies]
Expand All @@ -34,10 +39,11 @@ proptest = "1"
rand_core = { version = "0.9", features = ["os_rng"] }

[features]
default = ["arithmetic", "dsa", "pke", "pem", "std"]
alloc = ["elliptic-curve/alloc"]
std = ["alloc", "elliptic-curve/std", "signature?/std"]
default = ["arithmetic", "dsa", "pke", "pem"]
alloc = ["elliptic-curve/alloc", "signature?/alloc", "elliptic-curve/alloc"]
std = ["alloc", "elliptic-curve/std", "signature?/std", "os_rng"]

os_rng = ["rand_core/os_rng"]
arithmetic = ["dep:primeorder", "elliptic-curve/arithmetic"]
bits = ["arithmetic", "elliptic-curve/bits"]
dsa = ["arithmetic", "dep:rfc6979", "dep:signature", "dep:sm3"]
Expand Down
7 changes: 3 additions & 4 deletions sm2/src/dsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,11 @@ use signature::{Error, Result, SignatureEncoding};
#[cfg(feature = "alloc")]
use alloc::vec::Vec;

#[cfg(feature = "pkcs8")]
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
use crate::pkcs8::{
AlgorithmIdentifierRef, ObjectIdentifier, der::AnyRef, spki::AssociatedAlgorithmIdentifier,
AlgorithmIdentifierRef, ObjectIdentifier, der, der::AnyRef,
spki::AssociatedAlgorithmIdentifier, spki::SignatureBitStringEncoding,
};
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
use crate::pkcs8::{der, spki::SignatureBitStringEncoding};

/// SM2DSA signature serialized as bytes.
pub type SignatureBytes = [u8; Signature::BYTE_SIZE];
Expand Down
4 changes: 2 additions & 2 deletions sm2/src/dsa/signing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use signature::{
};
use sm3::Sm3;

#[cfg(feature = "pkcs8")]
#[cfg(all(feature = "pkcs8", feature = "alloc"))]
use crate::pkcs8::{
der::AnyRef,
spki::{AlgorithmIdentifier, AssociatedAlgorithmIdentifier, SignatureAlgorithmIdentifier},
Expand Down Expand Up @@ -230,7 +230,7 @@ fn sign_prehash_rfc6979(secret_scalar: &Scalar, prehash: &[u8], data: &[u8]) ->
Signature::from_scalars(r, s)
}

#[cfg(feature = "pkcs8")]
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl SignatureAlgorithmIdentifier for SigningKey {
type Params = AnyRef<'static>;

Expand Down
6 changes: 3 additions & 3 deletions sm2/src/dsa/verifying.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ use sm3::{Sm3, digest::Digest};
use alloc::{boxed::Box, string::String};

#[cfg(all(feature = "alloc", feature = "pkcs8"))]
use crate::pkcs8::{self, EncodePublicKey, spki};
#[cfg(feature = "pkcs8")]
use crate::pkcs8::{
self, EncodePublicKey,
der::AnyRef,
spki,
spki::{AlgorithmIdentifier, AssociatedAlgorithmIdentifier, SignatureAlgorithmIdentifier},
};

Expand Down Expand Up @@ -217,7 +217,7 @@ impl EncodePublicKey for VerifyingKey {
}
}

#[cfg(feature = "pkcs8")]
#[cfg(all(feature = "alloc", feature = "pkcs8"))]
impl SignatureAlgorithmIdentifier for VerifyingKey {
type Params = AnyRef<'static>;

Expand Down
237 changes: 126 additions & 111 deletions sm2/src/pke.rs
Original file line number Diff line number Diff line change
@@ -1,73 +1,52 @@
//! SM2 Encryption Algorithm (SM2) as defined in [draft-shen-sm2-ecdsa § 5].
//!
//! ## Usage
//!
//! NOTE: requires the `sm3` crate for digest functions and the `primeorder` crate for prime field operations.
//!
//! The `DecryptingKey` struct is used for decrypting messages that were encrypted using the SM2 encryption algorithm.
//! It is initialized with a `SecretKey` or a non-zero scalar value and can decrypt ciphertexts using the specified decryption mode.
#![cfg_attr(feature = "std", doc = "```")]
#![cfg_attr(not(feature = "std"), doc = "```ignore")]
//! # fn example() -> Result<(), Box<dyn std::error::Error>> {
//! use rand_core::OsRng; // requires 'os_rng` feature
//! use sm2::{
//! pke::{EncryptingKey, Mode},
//! {SecretKey, PublicKey}
//! };
//!
#![cfg_attr(feature = "alloc", doc = "```")]
#![cfg_attr(not(feature = "alloc"), doc = "```ignore")]
//! use sm2::pke::{EcDecrypt, EcEncrypt, Cipher, Mode};
//! use sm2::SecretKey;
//! use rand_core::OsRng;
//! // Encrypting
//! let secret_key = SecretKey::try_from_rng(&mut OsRng).unwrap(); // serialize with `::to_bytes()`
//! let public_key = secret_key.public_key();
//! let encrypting_key = EncryptingKey::new_with_mode(public_key, Mode::C1C2C3);
//! let plaintext = b"plaintext";
//! let ciphertext = encrypting_key.encrypt(&mut OsRng, plaintext)?;
//! let cipher = public_key.encrypt(plaintext).unwrap();
//! let ciphertext = cipher.to_vec(Mode::C1C3C2);
//!
//! use sm2::pke::DecryptingKey;
//! // Decrypting
//! let decrypting_key = DecryptingKey::new_with_mode(secret_key.to_nonzero_scalar(), Mode::C1C2C3);
//! assert_eq!(decrypting_key.decrypt(&ciphertext)?, plaintext);
//!
//! // Encrypting ASN.1 DER
//! let ciphertext = encrypting_key.encrypt_der(&mut OsRng, plaintext)?;
//!
//! // Decrypting ASN.1 DER
//! assert_eq!(decrypting_key.decrypt_der(&ciphertext)?, plaintext);
//!
//! Ok(())
//! # }
//! let cipher = Cipher::from_slice(&ciphertext, Mode::C1C3C2).unwrap();
//! let ciphertext = secret_key.decrypt(&cipher).unwrap();
//! assert_eq!(ciphertext, plaintext)
//! ```
//!
//!
//!
//!

use core::cmp::min;

use crate::AffinePoint;

#[cfg(feature = "alloc")]
use alloc::vec;

use elliptic_curve::{
bigint::{Encoding, U256, Uint},
pkcs8::der::{
Decode, DecodeValue, Encode, Length, Reader, Sequence, Tag, Writer, asn1::UintRef,
},
CurveArithmetic, Error, FieldBytesSize, Group, PrimeField, Result,
array::typenum::Unsigned,
ops::Reduce,
sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, Tag, ToEncodedPoint},
};

use elliptic_curve::{
Result,
pkcs8::der::{EncodeValue, asn1::OctetStringRef},
sec1::ToEncodedPoint,
};
use sm3::digest::DynDigest;
use primeorder::{AffinePoint, PrimeCurveParams};
use signature::digest::{FixedOutputReset, Output, OutputSizeUser, Update};

#[cfg(feature = "alloc")]
use alloc::{borrow::Cow, vec::Vec};

#[cfg(feature = "arithmetic")]
mod decrypting;
#[cfg(feature = "arithmetic")]
mod encrypting;
use crate::Sm2;
use sm3::Sm3;

#[cfg(feature = "arithmetic")]
pub use self::{decrypting::DecryptingKey, encrypting::EncryptingKey};
pub use self::{decrypting::EcDecrypt, encrypting::EcEncrypt};

/// Modes for the cipher encoding/decoding.
#[derive(Clone, Copy, Debug)]
Expand All @@ -77,101 +56,137 @@ pub enum Mode {
/// new mode
C1C3C2,
}

/// Represents a cipher structure containing encryption-related data (asn.1 format).
///
/// The `Cipher` structure includes the coordinates of the elliptic curve point (`x`, `y`),
/// the digest of the message, and the encrypted cipher text.
pub struct Cipher<'a> {
x: U256,
y: U256,
digest: &'a [u8],
cipher: &'a [u8],
/// TODO: ASN1 Encode and Decode
#[derive(Debug)]
pub struct Cipher<'a, C: CurveArithmetic = Sm2, D: OutputSizeUser = Sm3> {
c1: C::AffinePoint,
#[cfg(feature = "alloc")]
c2: Cow<'a, [u8]>,
#[cfg(not(feature = "alloc"))]
c2: &'a [u8],
c3: Output<D>,
}

impl<'a> Sequence<'a> for Cipher<'a> {}

impl EncodeValue for Cipher<'_> {
fn value_len(&self) -> elliptic_curve::pkcs8::der::Result<Length> {
UintRef::new(&self.x.to_be_bytes())?.encoded_len()?
+ UintRef::new(&self.y.to_be_bytes())?.encoded_len()?
+ OctetStringRef::new(self.digest)?.encoded_len()?
+ OctetStringRef::new(self.cipher)?.encoded_len()?
impl<'a, C, D> Cipher<'a, C, D>
where
C: PrimeCurveParams,
C::AffinePoint: ToEncodedPoint<C> + FromEncodedPoint<C>,
C::FieldBytesSize: ModulusSize,
D: OutputSizeUser,
{
/// Decode from slice
pub fn from_slice(cipher: &'a [u8], mode: Mode) -> Result<Self> {
let tag = Tag::from_u8(cipher.first().cloned().ok_or(Error)?)?;
let c1_len = tag.message_len(C::FieldBytesSize::USIZE);

// B1: get 𝐶1 from 𝐶
let (c1, c) = cipher.split_at(c1_len);
// verify that point c1 satisfies the elliptic curve
let encoded_c1 = EncodedPoint::<C>::from_bytes(c1)?;
let c1 = Option::from(C::AffinePoint::from_encoded_point(&encoded_c1)).ok_or(Error)?;
// B2: compute point 𝑆 = [ℎ]𝐶1
let scalar: C::Scalar = Reduce::<C::Uint>::reduce(C::Uint::from(C::FieldElement::S));
let s: C::ProjectivePoint = c1 * scalar;
if s.is_identity().into() {
return Err(Error);
}

let digest_size = D::output_size();
let (c2, c3_buf) = match mode {
Mode::C1C3C2 => {
let (c3, c2) = c.split_at(digest_size);
(c2, c3)
}
Mode::C1C2C3 => c.split_at(c.len() - digest_size),
};

let mut c3 = Output::<D>::default();
c3.clone_from_slice(c3_buf);

#[cfg(feature = "alloc")]
let c2 = Cow::Borrowed(c2);

Ok(Self { c1, c2, c3 })
}

fn encode_value(&self, writer: &mut impl Writer) -> elliptic_curve::pkcs8::der::Result<()> {
UintRef::new(&self.x.to_be_bytes())?.encode(writer)?;
UintRef::new(&self.y.to_be_bytes())?.encode(writer)?;
OctetStringRef::new(self.digest)?.encode(writer)?;
OctetStringRef::new(self.cipher)?.encode(writer)?;
Ok(())
/// Encode to Vec
#[cfg(feature = "alloc")]
pub fn to_vec(&self, mode: Mode) -> Vec<u8> {
let point = self.c1.to_encoded_point(false);
let len = point.len() + self.c2.len() + self.c3.len();
let mut result = Vec::with_capacity(len);
match mode {
Mode::C1C2C3 => {
result.extend(point.as_ref());
result.extend(self.c2.as_ref());
result.extend(&self.c3);
}
Mode::C1C3C2 => {
result.extend(point.as_ref());
result.extend(&self.c3);
result.extend(self.c2.as_ref());
}
}

result
}
}

impl<'a> DecodeValue<'a> for Cipher<'a> {
type Error = elliptic_curve::pkcs8::der::Error;

fn decode_value<R: Reader<'a>>(
decoder: &mut R,
header: elliptic_curve::pkcs8::der::Header,
) -> core::result::Result<Self, Self::Error> {
decoder.read_nested(header.length, |nr| {
let x = UintRef::decode(nr)?.as_bytes();
let y = UintRef::decode(nr)?.as_bytes();
let digest = OctetStringRef::decode(nr)?.into();
let cipher = OctetStringRef::decode(nr)?.into();
Ok(Cipher {
x: Uint::from_be_bytes(zero_pad_byte_slice(x)?),
y: Uint::from_be_bytes(zero_pad_byte_slice(y)?),
digest,
cipher,
})
})
/// Get C1
pub fn c1(&self) -> &C::AffinePoint {
&self.c1
}
/// Get C2
pub fn c2(&self) -> &[u8] {
#[cfg(feature = "alloc")]
return &self.c2;
#[cfg(not(feature = "alloc"))]
return self.c2;
}
/// Get C3
pub fn c3(&self) -> &Output<D> {
&self.c3
}
}

/// Performs key derivation using a hash function and elliptic curve point.
fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<()> {
let klen = c2.len();
/// Performs key derivation using a hash function and elliptic curve point.
/// Magic modification: Does it support streaming encryption and decryption?
fn kdf<D, C>(hasher: &mut D, kpb: AffinePoint<C>, msg: &[u8], c2_out: &mut [u8]) -> Result<()>
where
D: Update + FixedOutputReset,
C: CurveArithmetic + PrimeCurveParams,
FieldBytesSize<C>: ModulusSize,
AffinePoint<C>: ToEncodedPoint<C>,
{
let klen = msg.len();
let mut ct: i32 = 0x00000001;
let mut offset = 0;
let digest_size = hasher.output_size();
let mut ha = vec![0u8; digest_size];
let digest_size = D::output_size();
let mut ha = Output::<D>::default();
let encode_point = kpb.to_encoded_point(false);

hasher.reset();
while offset < klen {
hasher.update(encode_point.x().ok_or(elliptic_curve::Error)?);
hasher.update(encode_point.y().ok_or(elliptic_curve::Error)?);
hasher.update(encode_point.x().ok_or(Error)?);
hasher.update(encode_point.y().ok_or(Error)?);
hasher.update(&ct.to_be_bytes());

hasher
.finalize_into_reset(&mut ha)
.map_err(|_e| elliptic_curve::Error)?;
hasher.finalize_into_reset(&mut ha);

let xor_len = min(digest_size, klen - offset);
xor(c2, &ha, offset, xor_len);
xor(msg, c2_out, &ha, offset, xor_len);
offset += xor_len;
ct += 1;
}
Ok(())
}

/// XORs a portion of the buffer `c2` with a hash value.
fn xor(c2: &mut [u8], ha: &[u8], offset: usize, xor_len: usize) {
fn xor(msg: &[u8], c2_out: &mut [u8], ha: &[u8], offset: usize, xor_len: usize) {
for i in 0..xor_len {
c2[offset + i] ^= ha[i];
c2_out[offset + i] = msg[offset + i] ^ ha[i];
}
}

/// Converts a byte slice to a fixed-size array, padding with leading zeroes if necessary.
pub(crate) fn zero_pad_byte_slice<const N: usize>(
bytes: &[u8],
) -> elliptic_curve::pkcs8::der::Result<[u8; N]> {
let num_zeroes = N
.checked_sub(bytes.len())
.ok_or_else(|| Tag::Integer.length_error())?;

// Copy input into `N`-sized output buffer with leading zeroes
let mut output = [0u8; N];
output[num_zeroes..].copy_from_slice(bytes);
Ok(output)
}
Loading