Skip to content

Commit dba645a

Browse files
committed
Support public key literals and tidy up handling of Raw vs PodId
1 parent b604150 commit dba645a

File tree

6 files changed

+183
-21
lines changed

6 files changed

+183
-21
lines changed

src/backends/plonky2/primitives/ec/curve.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use core::ops::{Add, Mul};
55
use std::{
66
array, fmt,
77
ops::{AddAssign, Neg, Sub},
8+
str::FromStr,
89
sync::LazyLock,
910
};
1011

@@ -121,6 +122,27 @@ impl fmt::Display for Point {
121122
}
122123
}
123124

125+
impl FromStr for Point {
126+
type Err = Error;
127+
128+
fn from_str(s: &str) -> Result<Self, Self::Err> {
129+
let point_bytes = bs58::decode(s)
130+
.into_vec()
131+
.map_err(|e| Error::custom(format!("Base58 decode error: {}", e)))?;
132+
133+
if point_bytes.len() == 80 {
134+
// Non-compressed
135+
Ok(Point {
136+
x: ec_field_from_bytes(&point_bytes[..40])?,
137+
u: ec_field_from_bytes(&point_bytes[40..])?,
138+
})
139+
} else {
140+
// Compressed
141+
Self::from_bytes_into_subgroup(&point_bytes)
142+
}
143+
}
144+
}
145+
124146
impl Serialize for Point {
125147
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
126148
where

src/lang/grammar.pest

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ document = { SOI ~ (use_statement | custom_predicate_def | request_def)* ~ EOI }
3232
use_statement = { "use" ~ use_predicate_list ~ "from" ~ batch_ref }
3333
use_predicate_list = { import_name ~ ("," ~ import_name)* }
3434
import_name = { identifier | "_" }
35-
batch_ref = { literal_raw }
35+
batch_ref = { hash_hex }
3636

3737
request_def = { "REQUEST" ~ "(" ~ statement_list? ~ ")" }
3838

@@ -59,11 +59,13 @@ anchored_key = { wildcard ~ "[" ~ literal_string ~ "]" }
5959

6060
// Literal Values (ordered to avoid ambiguity, e.g., string before int)
6161
literal_value = {
62+
literal_public_key |
6263
literal_dict |
6364
literal_set |
6465
literal_array |
6566
literal_bool |
6667
literal_raw |
68+
literal_pod_id |
6769
literal_string |
6870
literal_int
6971
}
@@ -72,9 +74,12 @@ literal_value = {
7274
literal_int = @{ "-"? ~ ASCII_DIGIT+ }
7375
literal_bool = @{ "true" | "false" }
7476

75-
// literal_raw: 0x followed by exactly 32 PAIRS of hex digits (64 hex characters)
77+
// hash_hex: 0x followed by exactly 32 PAIRS of hex digits (64 hex characters)
7678
// representing a 32-byte value in big-endian order
77-
literal_raw = @{ "0x" ~ (ASCII_HEX_DIGIT ~ ASCII_HEX_DIGIT){32} }
79+
hash_hex = @{ "0x" ~ (ASCII_HEX_DIGIT ~ ASCII_HEX_DIGIT){32} }
80+
81+
literal_raw = { "Raw" ~ "(" ~ hash_hex ~ ")" }
82+
literal_pod_id = { hash_hex }
7883

7984
// String literal parsing based on https://pest.rs/book/examples/json.html
8085
literal_string = ${ "\"" ~ inner ~ "\"" } // Compound atomic string rule
@@ -85,6 +90,11 @@ char = { // Rule for a single logical character (unescaped or escaped)
8590
| "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4}) // Unicode escape sequence
8691
}
8792

93+
// PublicKey(...)
94+
base58_char = { '1'..'9' | 'A'..'H' | 'J'..'N' | 'P'..'Z' | 'a'..'k' | 'm'..'z' }
95+
base58_string = @{ base58_char+ }
96+
literal_public_key = { "PublicKey" ~ "(" ~ base58_string ~ ")" }
97+
8898
// Container Literals (recursive definition using literal_value)
8999
literal_array = { "[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" }
90100
literal_set = { "#[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" }
@@ -95,7 +105,9 @@ dict_pair = { literal_string ~ ":" ~ literal_value }
95105
test_identifier = { SOI ~ identifier ~ EOI }
96106
test_wildcard = { SOI ~ wildcard ~ EOI }
97107
test_literal_int = { SOI ~ literal_int ~ EOI }
98-
test_literal_raw = { SOI ~ literal_raw ~ EOI }
108+
test_hash_hex = { SOI ~ hash_hex ~ EOI }
109+
test_literal_raw = { SOI ~ literal_raw ~ EOI }
110+
test_literal_pod_id = { SOI ~ literal_pod_id ~ EOI }
99111
test_literal_value = { SOI ~ literal_value ~ EOI }
100112
test_statement = { SOI ~ statement ~ EOI }
101113
test_custom_predicate_def = { SOI ~ custom_predicate_def ~ EOI }

src/lang/mod.rs

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ mod tests {
2929
use crate::{
3030
lang::error::ProcessorError,
3131
middleware::{
32-
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate,
33-
Params, PodType, Predicate, StatementTmpl, StatementTmplArg, Value, Wildcard,
34-
SELF_ID_HASH,
32+
hash_str, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key,
33+
NativePredicate, Params, PodId, PodType, Predicate, RawValue, StatementTmpl,
34+
StatementTmplArg, Value, Wildcard, SELF_ID_HASH,
3535
},
3636
};
3737

@@ -854,6 +854,78 @@ mod tests {
854854
Ok(())
855855
}
856856

857+
#[test]
858+
fn test_e2e_literals() -> Result<(), LangError> {
859+
let pk = crate::backends::plonky2::primitives::ec::curve::Point::generator();
860+
let pk_b58 = pk.to_string();
861+
let pod_id = PodId(hash_str("test"));
862+
let raw = RawValue::from(1);
863+
let string = "hello";
864+
let int = 123;
865+
let bool = true;
866+
867+
let input = format!(
868+
r#"
869+
REQUEST(
870+
Equal(?A["pk"], PublicKey({}))
871+
Equal(?B["pod_id"], {})
872+
Equal(?C["raw"], Raw({}))
873+
Equal(?D["string"], "{}")
874+
Equal(?E["int"], {})
875+
Equal(?F["bool"], {})
876+
)
877+
"#,
878+
pk_b58, pod_id, raw, string, int, bool
879+
);
880+
/*
881+
REQUEST(
882+
Equal(?A["pk"], PublicKey(3t9fNuU194n7mSJPRdeaJRMqw6ZQCUddzvECWNe1k2b1rdBezXpJxF))
883+
Equal(?B["pod_id"], 0x735b31d3aad0f5b66002ffe1dc7d2eaa0ee9c59c09b641e8261530c5f3a02f29)
884+
Equal(?C["raw"], Raw(0x0000000000000000000000000000000000000000000000000000000000000001))
885+
Equal(?D["string"], "hello")
886+
Equal(?E["int"], 123)
887+
Equal(?F["bool"], true)
888+
)
889+
*/
890+
891+
let params = Params::default();
892+
let processed = parse(&input, &params, &[])?;
893+
let request_templates = processed.request_templates;
894+
895+
assert_eq!(request_templates.len(), 6);
896+
897+
let expected_templates = vec![
898+
StatementTmpl {
899+
pred: Predicate::Native(NativePredicate::Equal),
900+
args: vec![sta_ak(("A", 0), "pk"), sta_lit(Value::from(pk))],
901+
},
902+
StatementTmpl {
903+
pred: Predicate::Native(NativePredicate::Equal),
904+
args: vec![sta_ak(("B", 1), "pod_id"), sta_lit(Value::from(pod_id))],
905+
},
906+
StatementTmpl {
907+
pred: Predicate::Native(NativePredicate::Equal),
908+
args: vec![sta_ak(("C", 2), "raw"), sta_lit(Value::from(raw))],
909+
},
910+
StatementTmpl {
911+
pred: Predicate::Native(NativePredicate::Equal),
912+
args: vec![sta_ak(("D", 3), "string"), sta_lit(Value::from(string))],
913+
},
914+
StatementTmpl {
915+
pred: Predicate::Native(NativePredicate::Equal),
916+
args: vec![sta_ak(("E", 4), "int"), sta_lit(Value::from(int))],
917+
},
918+
StatementTmpl {
919+
pred: Predicate::Native(NativePredicate::Equal),
920+
args: vec![sta_ak(("F", 5), "bool"), sta_lit(Value::from(bool))],
921+
},
922+
];
923+
924+
assert_eq!(request_templates, expected_templates);
925+
926+
Ok(())
927+
}
928+
857929
#[test]
858930
fn test_e2e_use_unknown_batch() {
859931
let params = Params::default();

src/lang/parser.rs

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,37 +106,67 @@ mod tests {
106106
// Raw - Require 64 hex digits (32 bytes, equal to 4 * 64-bit field elements)
107107
assert_parses(
108108
Rule::literal_raw,
109-
"0x0000000000000000000000000000000000000000000000000000000000000000",
109+
"Raw(0x0000000000000000000000000000000000000000000000000000000000000000)",
110110
);
111111
assert_parses(
112112
Rule::literal_raw,
113-
"0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd",
113+
"Raw(0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd)",
114114
);
115-
let long_valid_raw = format!("0x{}", "a".repeat(64));
115+
let long_valid_raw = format!("Raw(0x{})", "a".repeat(64));
116116
assert_parses(Rule::literal_raw, &long_valid_raw);
117117

118118
// Use anchored rule for failure cases
119-
assert_fails(Rule::test_literal_raw, "0xabc"); // Fails (string is too short)
120-
assert_fails(Rule::test_literal_raw, "0x"); // Fails (needs at least one pair)
121-
assert_fails(Rule::test_literal_raw, &format!("0x{}", "a".repeat(66))); // Fails (string is too long)
119+
assert_fails(
120+
Rule::test_literal_raw,
121+
"0x0000000000000000000000000000000000000000000000000000000000000000)",
122+
); // Missing Raw() wrapper
123+
assert_fails(Rule::test_literal_raw, "Raw(0xabc)"); // Fails (string is too short)
124+
assert_fails(Rule::test_literal_raw, "Raw(0x)"); // Fails (needs at least one pair)
125+
assert_fails(
126+
Rule::test_literal_raw,
127+
&format!("Raw(0x{})", "a".repeat(66)),
128+
); // Fails (string is too long)
129+
130+
// PodId (essentially identical to Raw but without the wrapper)
131+
assert_parses(
132+
Rule::literal_pod_id,
133+
"0x0000000000000000000000000000000000000000000000000000000000000000",
134+
);
135+
assert_parses(
136+
Rule::literal_pod_id,
137+
"0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd",
138+
);
139+
let long_valid_pod_id = format!("0x{}", "a".repeat(64));
140+
assert_parses(Rule::literal_pod_id, &long_valid_pod_id);
141+
142+
assert_fails(Rule::test_literal_pod_id, "0xabc"); // Fails (string is too short)
143+
assert_fails(Rule::test_literal_pod_id, "0x"); // Fails (needs at least one pair)
144+
assert_fails(Rule::test_literal_pod_id, &format!("0x{}", "a".repeat(66))); // Fails (string is too long)
122145

123146
// String
124147
assert_parses(Rule::literal_string, "\"hello\"");
125148
assert_parses(Rule::literal_string, "\"escaped \\\" quote\"");
126149
assert_parses(Rule::literal_string, "\"\\\\ backslash\"");
127150
assert_parses(Rule::literal_string, "\"\\uABCD\"");
128151
assert_fails(Rule::literal_string, "\"unterminated");
152+
153+
// PublicKey
154+
assert_parses(Rule::literal_public_key, "PublicKey(base58string)");
155+
assert_fails(Rule::literal_public_key, "PublicKey(OhNo)"); // Fails because O is not valid base58
156+
129157
// Array
130158
assert_parses(Rule::literal_array, "[]");
131159
assert_parses(Rule::literal_array, "[1, \"two\", true]");
132160
assert_parses(Rule::literal_array, "[ [1], #[2] ]");
161+
133162
// Set
134163
assert_parses(Rule::literal_set, "#[]");
135164
assert_parses(Rule::literal_set, "#[1, 2, 3]");
136165
assert_parses(
137166
Rule::literal_set,
138167
"#[ \"a\", 0x0000000000000000000000000000000000000000000000000000000000000000 ]",
139168
);
169+
140170
// Dict
141171
assert_parses(Rule::literal_dict, "{}");
142172
assert_parses(Rule::literal_dict, "{ \"name\": \"Alice\", \"age\": 30 }");

src/lang/processor.rs

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use plonky2::field::types::Field;
88

99
use super::error::ProcessorError;
1010
use crate::{
11+
backends::plonky2::primitives::ec::curve::Point,
1112
frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder},
1213
lang::parser::Rule,
1314
middleware::{
@@ -335,10 +336,10 @@ fn validate_and_build_statement_template(
335336
| NativePredicate::Lt
336337
| NativePredicate::LtEq
337338
| NativePredicate::SetContains
339+
| NativePredicate::NotContains
338340
| NativePredicate::DictNotContains
339341
| NativePredicate::SetNotContains => 2,
340-
NativePredicate::NotContains
341-
| NativePredicate::Contains
342+
NativePredicate::Contains
342343
| NativePredicate::ArrayContains
343344
| NativePredicate::DictContains
344345
| NativePredicate::SumOf
@@ -523,7 +524,6 @@ fn process_and_add_custom_predicate_to_batch(
523524
let public_args_strs: Vec<&str> = public_arg_strings.iter().map(AsRef::as_ref).collect();
524525
let private_args_strs: Vec<&str> = private_arg_strings.iter().map(AsRef::as_ref).collect();
525526
let sts_slice: &[StatementTmplBuilder] = &statement_builders;
526-
527527
if conjunction {
528528
cpb_builder.predicate_and(&name, &public_args_strs, &private_args_strs, sts_slice)?;
529529
} else {
@@ -667,11 +667,11 @@ fn process_literal_value(
667667
Ok(Value::from(val))
668668
}
669669
Rule::literal_raw => {
670-
let full_literal_str = inner_lit.as_str();
670+
let full_literal_str = inner_lit.clone().into_inner().next().unwrap();
671671
let hex_str_no_prefix = full_literal_str
672+
.as_str()
672673
.strip_prefix("0x")
673-
.unwrap_or(full_literal_str);
674-
674+
.unwrap_or(full_literal_str.as_str());
675675
parse_hex_str_to_raw_value(hex_str_no_prefix)
676676
.map_err(|e| match e {
677677
ProcessorError::InvalidLiteralFormat { kind, value, .. } => {
@@ -694,6 +694,27 @@ fn process_literal_value(
694694
})
695695
.map(Value::from)
696696
}
697+
Rule::literal_pod_id => {
698+
let hex_str_no_prefix = inner_lit
699+
.as_str()
700+
.strip_prefix("0x")
701+
.unwrap_or(inner_lit.as_str());
702+
let pod_id = parse_hex_str_to_pod_id(hex_str_no_prefix)?;
703+
Ok(Value::from(pod_id))
704+
}
705+
Rule::literal_public_key => {
706+
let pk_str_pair = inner_lit.into_inner().next().unwrap();
707+
let pk_b58 = pk_str_pair.as_str();
708+
let point: Point =
709+
pk_b58
710+
.parse()
711+
.map_err(|e| ProcessorError::InvalidLiteralFormat {
712+
kind: "PublicKey".to_string(),
713+
value: format!("{} (error: {})", pk_b58, e),
714+
span: Some(get_span(&pk_str_pair)),
715+
})?;
716+
Ok(Value::from(point))
717+
}
697718
Rule::literal_string => Ok(Value::from(parse_pest_string_literal(&inner_lit)?)),
698719
Rule::literal_array => {
699720
let elements: Result<Vec<Value>, ProcessorError> = inner_lit
@@ -823,6 +844,11 @@ fn parse_hex_str_to_raw_value(hex_str: &str) -> Result<middleware::RawValue, Pro
823844
Ok(middleware::RawValue(v))
824845
}
825846

847+
fn parse_hex_str_to_pod_id(hex_str: &str) -> Result<middleware::PodId, ProcessorError> {
848+
let raw = parse_hex_str_to_raw_value(hex_str)?;
849+
Ok(middleware::PodId(raw.into()))
850+
}
851+
826852
// Helper to resolve a wildcard name string to an indexed middleware::Wildcard
827853
// based on an ordered list of names from the current scope (e.g., request or predicate def).
828854
fn resolve_wildcard(

src/middleware/basetypes.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,15 +200,15 @@ impl Ord for Hash {
200200
impl fmt::Display for Hash {
201201
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202202
if f.alternate() {
203-
write!(f, "0x{}", self.encode_hex::<String>())
204-
} else {
205203
// display first hex digit in big endian
206204
write!(f, "0x")?;
207205
let v3 = self.0[3].to_canonical_u64();
208206
for i in 0..4 {
209207
write!(f, "{:02x}", (v3 >> ((7 - i) * 8)) & 0xff)?;
210208
}
211209
write!(f, "…")
210+
} else {
211+
write!(f, "0x{}", self.encode_hex::<String>())
212212
}
213213
}
214214
}

0 commit comments

Comments
 (0)