Skip to content

Commit b6c96dc

Browse files
dd23kc1212fegmorte
authored
feat: add extra_data to core_client decryption (#470)
* chore: change public decryption extra_data to &[u8] instead of Vec<u8> * chore: pass extra_data from request in core-client * chore: properly propagate extra_data * chore: fix integration test build * chore: fix test * chore: more tests on extra_data * chore: bump aws-lc-rs * chore: update cargo audit * ci: upgrade zizmor action * chore: bump zismor version * chore: unify parsing of extra_data * chore: fix typos in docstrings * chore: remove explicit zizmor version in workflow * doc: update cargo audit comment --------- Co-authored-by: Kelong Cong <kc1212@users.noreply.github.com> Co-authored-by: Frederic Egmorte <frederic.egmorte@zama.ai>
1 parent 06095c2 commit b6c96dc

33 files changed

+229
-55
lines changed

.cargo/audit.toml

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,17 @@
22
# permanently specified in this file.
33

44
[advisories]
5-
# The ignored vulnerability RUSTSEC-2023-0071 is not applicable in our use-case
6-
ignore = ["RUSTSEC-2023-0071"]
5+
# RUSTSEC-2023-0071 is not applicable in our use-case.
6+
# RUSTSEC-2026-0049 is not an urgent fix for us. It can only be triggered
7+
# when talking to AWS APIs by allowing the attacker to trick us into accepting
8+
# expired certificates. If AWS's TLS certificates are stolen, the consequences
9+
# are of a different magnitude and not something we can do much about.
10+
# For the time being we also require the "connector-hyper-0-14-x" feature
11+
# because we need to specify fixed server names for TLS connections (because
12+
# when socat proxies are used, the connection goes to localhost and if we don't
13+
# overload the server names, the handshake would fail because localhost isn't
14+
# `sts.*.amazonaws.com`).
15+
ignore = ["RUSTSEC-2023-0071", "RUSTSEC-2026-0049"]
716
informational_warnings = ["unmaintained"]
817
severity_threshold = "medium"
918

.cargo/deny.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ allow = [
100100
"ISC",
101101
"MIT",
102102
"MPL-2.0",
103-
"OpenSSL",
103+
# "OpenSSL", no longer used
104104
# "Unicode-DFS-2016", no longer used
105105
"Unicode-3.0",
106106
"Unlicense",

.github/workflows/ci_lint.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ jobs:
5151
persist-credentials: false
5252

5353
- name: Run zizmor
54-
uses: zizmorcore/zizmor-action@e673c3917a1aef3c65c972347ed84ccd013ecda4 # v0.2.0
54+
uses: zizmorcore/zizmor-action@71321a20a9ded102f6e9ce5718a2fcec2c4f70d8 # v0.5.2
5555
with:
5656
persona: pedantic
57-
version: 1.15.2

Cargo.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ async-trait = "=0.1.89" # Async trait support - MEDIUM RISK: Reputable individ
9191
async_cell = "0.2.2" # Async cell implementation - HIGH RISK: Individual maintainer, very low popularity
9292
attestation-doc-validation = { version = "=0.10.0" } # AWS Nitro attestation validation - LOW RISK: Evervault (reputable security company), security-critical but trusted
9393
aws-config = { version = "=1.8.12" } # AWS SDK configuration - LOW RISK: Official AWS SDK, actively maintained
94-
aws-lc-rs = { version = "=1.16.1" } # AWS-LC dependency pinned because of an earlier version vulnerability - LOW RISK: Official AWS SDK
94+
aws-lc-rs = { version = "=1.16.2" } # AWS-LC dependency pinned because of an earlier version vulnerability - LOW RISK: Official AWS SDK
9595
aws-nitro-enclaves-nsm-api = { version = "=0.4.0" } # AWS Nitro Enclaves NSM API - LOW RISK: Official AWS SDK
9696
aws-sdk-kms = { version = "=1.98.0" } # AWS KMS client - LOW RISK: Official AWS SDK for key management
9797
aws-sdk-s3 = { version = "=1.120.0" } # AWS S3 client - LOW RISK: Official AWS SDK for object storage
98-
aws-smithy-runtime = { version = "=1.10.3", features = ["client", "connector-hyper-0-14-x"] } # AWS Smithy runtime - LOW RISK: Official AWS runtime library
98+
aws-smithy-runtime = { version = "=1.10.3", features = ["client", "connector-hyper-0-14-x"] } # AWS Smithy runtime - LOW RISK: Official AWS runtime library. Remove "connector-hyper-0-14-x" when possible as it introduces a potential vulnerability.
9999
aws-smithy-runtime-api = { version = "=1.11.6" } # AWS Smithy runtime API - LOW RISK: Official AWS runtime API
100100
aws-smithy-types = { version = "=1.4.6" } # AWS Smithy types - LOW RISK: Official AWS type definitions
101101
axum = { version = "=0.8.8", features = ["tokio"] } # Web framework - LOW RISK: tokio-rs team, 168M+ downloads, actively maintained

core-client/src/crsgen.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pub(crate) async fn do_crsgen(
3333
destination_prefix: &Path,
3434
context_id: Option<ContextId>,
3535
epoch_id: Option<EpochId>,
36+
extra_data: Vec<u8>,
3637
) -> anyhow::Result<RequestId> {
3738
let req_id = RequestId::new_random(rng);
3839

@@ -115,7 +116,7 @@ pub(crate) async fn do_crsgen(
115116
destination_prefix,
116117
req_id,
117118
domain,
118-
vec![],
119+
extra_data,
119120
resp_response_vec,
120121
cmd_conf.download_all,
121122
)

core-client/src/decrypt.rs

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ fn check_ext_pt_signature(
2626
external_handles: Vec<Vec<u8>>,
2727
domain: Eip712Domain,
2828
kms_addrs: &[alloy_primitives::Address],
29-
extra_data: Vec<u8>,
29+
extra_data: &[u8],
3030
) -> anyhow::Result<()> {
3131
tracing::debug!(
3232
"Checking signature for PTs: {:?}, ext. handles: {:?}, extra_data: {}, ext. sig {}",
3333
plaintexts,
3434
external_handles,
35-
hex::encode(&extra_data),
35+
hex::encode(extra_data),
3636
hex::encode(external_sig)
3737
);
3838
let message = compute_public_decryption_message(external_handles, plaintexts, extra_data)?;
@@ -54,6 +54,7 @@ fn check_external_decryption_signature(
5454
external_handles: &[Vec<u8>],
5555
domain: &Eip712Domain,
5656
kms_addrs: &[alloy_primitives::Address],
57+
extra_data: &[u8],
5758
) -> anyhow::Result<()> {
5859
let mut results = Vec::new();
5960
for response in responses {
@@ -67,7 +68,7 @@ fn check_external_decryption_signature(
6768
external_handles.to_owned(),
6869
domain.clone(),
6970
kms_addrs,
70-
vec![],
71+
extra_data,
7172
)?;
7273

7374
for (idx, pt) in payload.plaintexts.iter().enumerate() {
@@ -114,6 +115,7 @@ pub(crate) async fn do_public_decrypt<R: Rng + CryptoRng>(
114115
num_expected_responses: usize,
115116
inter_request_delay: tokio::time::Duration,
116117
parallel_requests: usize,
118+
extra_data: Vec<u8>,
117119
) -> anyhow::Result<Vec<(Option<RequestId>, String)>> {
118120
let mut timings_start = HashMap::new();
119121
let mut durations = Vec::new();
@@ -136,6 +138,7 @@ pub(crate) async fn do_public_decrypt<R: Rng + CryptoRng>(
136138
let core_endpoints_resp = core_endpoints_resp.clone();
137139
let ptxt = ptxt.clone();
138140
let kms_addrs = kms_addrs.clone();
141+
let extra_data = extra_data.clone();
139142

140143
// start timing measurement for this request
141144
timings_start.insert(req_id, tokio::time::Instant::now()); // start timing for this request
@@ -149,6 +152,7 @@ pub(crate) async fn do_public_decrypt<R: Rng + CryptoRng>(
149152
context_id.as_ref(),
150153
&key_id.into(),
151154
epoch_id.as_ref(),
155+
&extra_data,
152156
)?;
153157

154158
// make parallel requests by calling [decrypt] in a thread
@@ -246,6 +250,7 @@ pub(crate) async fn do_user_decrypt<R: Rng + CryptoRng>(
246250
num_expected_responses: usize,
247251
inter_request_delay: tokio::time::Duration,
248252
parallel_requests: usize,
253+
extra_data: Vec<u8>,
249254
) -> anyhow::Result<Vec<(Option<RequestId>, String)>> {
250255
let mut join_set: JoinSet<Result<_, anyhow::Error>> = JoinSet::new();
251256
let mut timings_start = HashMap::new();
@@ -267,6 +272,7 @@ pub(crate) async fn do_user_decrypt<R: Rng + CryptoRng>(
267272
let core_endpoints_req = core_endpoints_req.clone();
268273
let core_endpoints_resp = core_endpoints_resp.clone();
269274
let original_plaintext = ptxt.clone();
275+
let extra_data = extra_data.clone();
270276

271277
// start timing measurement for this request
272278
timings_start.insert(req_id, tokio::time::Instant::now()); // start timing for this request
@@ -281,6 +287,7 @@ pub(crate) async fn do_user_decrypt<R: Rng + CryptoRng>(
281287
context_id.as_ref(),
282288
epoch_id.as_ref(),
283289
PkeSchemeType::MlKem512,
290+
&extra_data,
284291
)?;
285292

286293
let (user_decrypt_req, enc_pk, enc_sk) = user_decrypt_req_tuple;
@@ -587,7 +594,8 @@ pub(crate) async fn get_public_decrypt_responses(
587594
.clone(),
588595
};
589596

590-
let (domain, external_handles) = if let Some(decryption_request) = dec_req.as_ref() {
597+
let (domain, external_handles, extra_data) = if let Some(decryption_request) = dec_req.as_ref()
598+
{
591599
let domain_msg = decryption_request
592600
.domain
593601
.as_ref()
@@ -599,7 +607,8 @@ pub(crate) async fn get_public_decrypt_responses(
599607
.iter()
600608
.map(|ct| ct.external_handle.clone())
601609
.collect();
602-
(domain, external_handles)
610+
let extra_data = decryption_request.extra_data.clone();
611+
(domain, external_handles, extra_data)
603612
} else {
604613
//If the decryption request isn't provided we assume it was dummy domains and handles
605614
let num_handles = resp_response_vec
@@ -610,7 +619,16 @@ pub(crate) async fn get_public_decrypt_responses(
610619
.ok_or_else(|| anyhow::anyhow!("missing payload in first decryption response"))?
611620
.plaintexts
612621
.len();
613-
(dummy_domain(), vec![dummy_handle(); num_handles])
622+
let extra_data = resp_response_vec
623+
.first()
624+
.ok_or_else(|| anyhow::anyhow!("no public decryption responses available"))?
625+
.extra_data
626+
.clone();
627+
(
628+
dummy_domain(),
629+
vec![dummy_handle(); num_handles],
630+
extra_data,
631+
)
614632
};
615633

616634
// check the internal signatures
@@ -627,6 +645,7 @@ pub(crate) async fn get_public_decrypt_responses(
627645
&external_handles,
628646
&domain,
629647
kms_addrs,
648+
&extra_data,
630649
)?;
631650

632651
tracing::info!(

core-client/src/keygen.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ pub(crate) async fn do_keygen(
6565
insecure: bool,
6666
shared_config: &SharedKeyGenParameters,
6767
destination_prefix: &Path,
68+
extra_data: Vec<u8>,
6869
) -> anyhow::Result<RequestId> {
6970
let req_id = RequestId::new_random(rng);
7071

@@ -162,7 +163,7 @@ pub(crate) async fn do_keygen(
162163
destination_prefix,
163164
req_id,
164165
domain,
165-
vec![],
166+
extra_data,
166167
resp_response_vec,
167168
cmd_conf.download_all,
168169
shared_config.compressed,

core-client/src/lib.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,20 @@ impl CipherArguments {
494494
}
495495
}
496496
}
497+
498+
pub fn get_extra_data(&self) -> Vec<u8> {
499+
let hex_str = match self {
500+
CipherArguments::FromFile(cipher_file) => &cipher_file.extra_data,
501+
CipherArguments::FromArgs(cipher_parameters) => &cipher_parameters.extra_data,
502+
};
503+
parse_extra_data(hex_str)
504+
}
505+
}
506+
507+
// Helper function to parse the extra data from the CLI arguments, with the same logic for both CipherParameters and CipherFile.
508+
// Defaults to an empty byte vector if the extra data is not provided or if the hex parsing fails.
509+
fn parse_extra_data(hex_str: &Option<String>) -> Vec<u8> {
510+
parse_hex(hex_str.as_deref().unwrap_or("")).unwrap_or_default()
497511
}
498512

499513
#[derive(Debug, Args, Clone, Serialize, Deserialize)]
@@ -553,6 +567,11 @@ pub struct CipherParameters {
553567
#[serde(skip_serializing, skip_deserializing)]
554568
#[clap(long, default_value_t = false)]
555569
pub compressed_keys: bool,
570+
/// Optional extra data (hex-encoded) to include in the request.
571+
/// Can optionally have a "0x" prefix.
572+
#[serde(skip_serializing, skip_deserializing)]
573+
#[clap(long)]
574+
pub extra_data: Option<String>,
556575
}
557576

558577
#[derive(Debug, Args, Clone)]
@@ -573,6 +592,10 @@ pub struct CipherFile {
573592
/// Number of requests to be sent in parallel (at most num_requests) before waiting for inter_request_delay_ms.
574593
#[clap(long, short = 'p', default_value_t = 0)]
575594
pub parallel_requests: usize,
595+
/// Optional extra data (hex-encoded) to include in the request.
596+
/// Can optionally have a "0x" prefix.
597+
#[clap(long)]
598+
pub extra_data: Option<String>,
576599
}
577600

578601
#[derive(Debug, Serialize, Deserialize)]
@@ -601,6 +624,10 @@ pub struct SharedKeyGenParameters {
601624
pub use_existing_key_tag: bool,
602625
pub context_id: Option<ContextId>,
603626
pub epoch_id: Option<EpochId>,
627+
/// Optional extra data (hex-encoded) to include in the request.
628+
/// Can optionally have a "0x" prefix.
629+
#[clap(long)]
630+
pub extra_data: Option<String>,
604631
}
605632

606633
#[derive(Debug, Parser, Clone)]
@@ -626,6 +653,10 @@ pub struct CrsParameters {
626653
pub epoch_id: Option<EpochId>,
627654
#[clap(long)]
628655
pub context_id: Option<ContextId>,
656+
/// Optional extra data (hex-encoded) to include in the request.
657+
/// Can optionally have a "0x" prefix.
658+
#[clap(long)]
659+
pub extra_data: Option<String>,
629660
}
630661

631662
impl Default for CrsParameters {
@@ -634,6 +665,7 @@ impl Default for CrsParameters {
634665
max_num_bits: 2048,
635666
epoch_id: None,
636667
context_id: None,
668+
extra_data: None,
637669
}
638670
}
639671
}
@@ -1667,6 +1699,7 @@ pub async fn execute_cmd(
16671699
num_expected_responses,
16681700
cipher_args.get_inter_request_delay_ms(),
16691701
cipher_args.get_parallel_requests(),
1702+
cipher_args.get_extra_data(),
16701703
)
16711704
.await?
16721705
}
@@ -1742,6 +1775,7 @@ pub async fn execute_cmd(
17421775
num_expected_responses,
17431776
cipher_args.get_inter_request_delay_ms(),
17441777
cipher_args.get_parallel_requests(),
1778+
cipher_args.get_extra_data(),
17451779
)
17461780
.await?
17471781
}
@@ -1767,6 +1801,7 @@ pub async fn execute_cmd(
17671801
false,
17681802
shared_args,
17691803
destination_prefix,
1804+
parse_extra_data(&shared_args.extra_data),
17701805
)
17711806
.await?;
17721807

@@ -1792,6 +1827,7 @@ pub async fn execute_cmd(
17921827
true,
17931828
shared_args,
17941829
destination_prefix,
1830+
parse_extra_data(&shared_args.extra_data),
17951831
)
17961832
.await?;
17971833

@@ -1801,6 +1837,7 @@ pub async fn execute_cmd(
18011837
max_num_bits,
18021838
epoch_id,
18031839
context_id,
1840+
extra_data,
18041841
}) => {
18051842
let mut internal_client = internal_client.unwrap();
18061843
tracing::info!(
@@ -1822,6 +1859,7 @@ pub async fn execute_cmd(
18221859
destination_prefix,
18231860
*context_id,
18241861
*epoch_id,
1862+
parse_extra_data(extra_data),
18251863
)
18261864
.await?;
18271865
vec![(Some(req_id), "crsgen done".to_string())]
@@ -1830,6 +1868,7 @@ pub async fn execute_cmd(
18301868
max_num_bits,
18311869
epoch_id,
18321870
context_id,
1871+
extra_data,
18331872
}) => {
18341873
let mut internal_client = internal_client.unwrap();
18351874
tracing::info!(
@@ -1851,6 +1890,7 @@ pub async fn execute_cmd(
18511890
destination_prefix,
18521891
*context_id,
18531892
*epoch_id,
1893+
parse_extra_data(extra_data),
18541894
)
18551895
.await?;
18561896
vec![(Some(req_id), "insecure crsgen done".to_string())]

0 commit comments

Comments
 (0)