Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions src/backends/plonky2/primitives/ec/curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::ops::{Add, Mul};
use std::{
array, fmt,
ops::{AddAssign, Neg, Sub},
str::FromStr,
sync::LazyLock,
};

Expand Down Expand Up @@ -121,6 +122,27 @@ impl fmt::Display for Point {
}
}

impl FromStr for Point {
type Err = Error;

fn from_str(s: &str) -> Result<Self, Self::Err> {
let point_bytes = bs58::decode(s)
.into_vec()
.map_err(|e| Error::custom(format!("Base58 decode error: {}", e)))?;

if point_bytes.len() == 80 {
// Non-compressed
Ok(Point {
x: ec_field_from_bytes(&point_bytes[..40])?,
u: ec_field_from_bytes(&point_bytes[40..])?,
})
} else {
// Compressed
Self::from_bytes_into_subgroup(&point_bytes)
}
}
}

impl Serialize for Point {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
Expand All @@ -137,19 +159,7 @@ impl<'de> Deserialize<'de> for Point {
D: Deserializer<'de>,
{
let point_b58 = String::deserialize(deserializer)?;
let point_bytes: Vec<u8> = bs58::decode(point_b58)
.into_vec()
.map_err(serde::de::Error::custom)?;
if point_bytes.len() == 80 {
// Non-compressed
Ok(Point {
x: ec_field_from_bytes(&point_bytes[..40]).map_err(serde::de::Error::custom)?,
u: ec_field_from_bytes(&point_bytes[40..]).map_err(serde::de::Error::custom)?,
})
} else {
// Compressed
Self::from_bytes_into_subgroup(&point_bytes).map_err(serde::de::Error::custom)
}
Self::from_str(&point_b58).map_err(serde::de::Error::custom)
}
}

Expand Down
20 changes: 16 additions & 4 deletions src/lang/grammar.pest
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ document = { SOI ~ (use_statement | custom_predicate_def | request_def)* ~ EOI }
use_statement = { "use" ~ use_predicate_list ~ "from" ~ batch_ref }
use_predicate_list = { import_name ~ ("," ~ import_name)* }
import_name = { identifier | "_" }
batch_ref = { literal_raw }
batch_ref = { hash_hex }

request_def = { "REQUEST" ~ "(" ~ statement_list? ~ ")" }

Expand All @@ -59,11 +59,13 @@ anchored_key = { wildcard ~ "[" ~ literal_string ~ "]" }

// Literal Values (ordered to avoid ambiguity, e.g., string before int)
literal_value = {
literal_public_key |
literal_dict |
literal_set |
literal_array |
literal_bool |
literal_raw |
literal_pod_id |
literal_string |
literal_int
}
Expand All @@ -72,9 +74,12 @@ literal_value = {
literal_int = @{ "-"? ~ ASCII_DIGIT+ }
literal_bool = @{ "true" | "false" }

// literal_raw: 0x followed by exactly 32 PAIRS of hex digits (64 hex characters)
// hash_hex: 0x followed by exactly 32 PAIRS of hex digits (64 hex characters)
// representing a 32-byte value in big-endian order
literal_raw = @{ "0x" ~ (ASCII_HEX_DIGIT ~ ASCII_HEX_DIGIT){32} }
hash_hex = @{ "0x" ~ (ASCII_HEX_DIGIT ~ ASCII_HEX_DIGIT){32} }

literal_raw = { "Raw" ~ "(" ~ hash_hex ~ ")" }
literal_pod_id = { hash_hex }

// String literal parsing based on https://pest.rs/book/examples/json.html
literal_string = ${ "\"" ~ inner ~ "\"" } // Compound atomic string rule
Expand All @@ -85,6 +90,11 @@ char = { // Rule for a single logical character (unescaped or escaped)
| "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4}) // Unicode escape sequence
}

// PublicKey(...)
base58_char = { '1'..'9' | 'A'..'H' | 'J'..'N' | 'P'..'Z' | 'a'..'k' | 'm'..'z' }
base58_string = @{ base58_char+ }
literal_public_key = { "PublicKey" ~ "(" ~ base58_string ~ ")" }

// Container Literals (recursive definition using literal_value)
literal_array = { "[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" }
literal_set = { "#[" ~ (literal_value ~ ("," ~ literal_value)*)? ~ "]" }
Expand All @@ -95,7 +105,9 @@ dict_pair = { literal_string ~ ":" ~ literal_value }
test_identifier = { SOI ~ identifier ~ EOI }
test_wildcard = { SOI ~ wildcard ~ EOI }
test_literal_int = { SOI ~ literal_int ~ EOI }
test_literal_raw = { SOI ~ literal_raw ~ EOI }
test_hash_hex = { SOI ~ hash_hex ~ EOI }
test_literal_raw = { SOI ~ literal_raw ~ EOI }
test_literal_pod_id = { SOI ~ literal_pod_id ~ EOI }
test_literal_value = { SOI ~ literal_value ~ EOI }
test_statement = { SOI ~ statement ~ EOI }
test_custom_predicate_def = { SOI ~ custom_predicate_def ~ EOI }
78 changes: 75 additions & 3 deletions src/lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ mod tests {
use crate::{
lang::error::ProcessorError,
middleware::{
CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key, NativePredicate,
Params, PodType, Predicate, StatementTmpl, StatementTmplArg, Value, Wildcard,
SELF_ID_HASH,
hash_str, CustomPredicate, CustomPredicateBatch, CustomPredicateRef, Key,
NativePredicate, Params, PodId, PodType, Predicate, RawValue, StatementTmpl,
StatementTmplArg, Value, Wildcard, SELF_ID_HASH,
},
};

Expand Down Expand Up @@ -854,6 +854,78 @@ mod tests {
Ok(())
}

#[test]
fn test_e2e_literals() -> Result<(), LangError> {
let pk = crate::backends::plonky2::primitives::ec::curve::Point::generator();
let pk_b58 = pk.to_string();
let pod_id = PodId(hash_str("test"));
let raw = RawValue::from(1);
let string = "hello";
let int = 123;
let bool = true;

let input = format!(
r#"
REQUEST(
Equal(?A["pk"], PublicKey({}))
Equal(?B["pod_id"], {:#})
Equal(?C["raw"], Raw({:#}))
Equal(?D["string"], "{}")
Equal(?E["int"], {})
Equal(?F["bool"], {})
)
"#,
pk_b58, pod_id, raw, string, int, bool
);
/*
REQUEST(
Equal(?A["pk"], PublicKey(3t9fNuU194n7mSJPRdeaJRMqw6ZQCUddzvECWNe1k2b1rdBezXpJxF))
Equal(?B["pod_id"], 0x735b31d3aad0f5b66002ffe1dc7d2eaa0ee9c59c09b641e8261530c5f3a02f29)
Equal(?C["raw"], Raw(0x0000000000000000000000000000000000000000000000000000000000000001))
Equal(?D["string"], "hello")
Equal(?E["int"], 123)
Equal(?F["bool"], true)
)
*/

let params = Params::default();
let processed = parse(&input, &params, &[])?;
let request_templates = processed.request_templates;

assert_eq!(request_templates.len(), 6);

let expected_templates = vec![
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("A", 0), "pk"), sta_lit(Value::from(pk))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("B", 1), "pod_id"), sta_lit(Value::from(pod_id))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("C", 2), "raw"), sta_lit(Value::from(raw))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("D", 3), "string"), sta_lit(Value::from(string))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("E", 4), "int"), sta_lit(Value::from(int))],
},
StatementTmpl {
pred: Predicate::Native(NativePredicate::Equal),
args: vec![sta_ak(("F", 5), "bool"), sta_lit(Value::from(bool))],
},
];

assert_eq!(request_templates, expected_templates);

Ok(())
}

#[test]
fn test_e2e_use_unknown_batch() {
let params = Params::default();
Expand Down
42 changes: 36 additions & 6 deletions src/lang/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,37 +106,67 @@ mod tests {
// Raw - Require 64 hex digits (32 bytes, equal to 4 * 64-bit field elements)
assert_parses(
Rule::literal_raw,
"0x0000000000000000000000000000000000000000000000000000000000000000",
"Raw(0x0000000000000000000000000000000000000000000000000000000000000000)",
);
assert_parses(
Rule::literal_raw,
"0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd",
"Raw(0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd)",
);
let long_valid_raw = format!("0x{}", "a".repeat(64));
let long_valid_raw = format!("Raw(0x{})", "a".repeat(64));
assert_parses(Rule::literal_raw, &long_valid_raw);

// Use anchored rule for failure cases
assert_fails(Rule::test_literal_raw, "0xabc"); // Fails (string is too short)
assert_fails(Rule::test_literal_raw, "0x"); // Fails (needs at least one pair)
assert_fails(Rule::test_literal_raw, &format!("0x{}", "a".repeat(66))); // Fails (string is too long)
assert_fails(
Rule::test_literal_raw,
"0x0000000000000000000000000000000000000000000000000000000000000000)",
); // Missing Raw() wrapper
assert_fails(Rule::test_literal_raw, "Raw(0xabc)"); // Fails (string is too short)
assert_fails(Rule::test_literal_raw, "Raw(0x)"); // Fails (needs at least one pair)
assert_fails(
Rule::test_literal_raw,
&format!("Raw(0x{})", "a".repeat(66)),
); // Fails (string is too long)

// PodId (essentially identical to Raw but without the wrapper)
assert_parses(
Rule::literal_pod_id,
"0x0000000000000000000000000000000000000000000000000000000000000000",
);
assert_parses(
Rule::literal_pod_id,
"0xabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcdabcd",
);
let long_valid_pod_id = format!("0x{}", "a".repeat(64));
assert_parses(Rule::literal_pod_id, &long_valid_pod_id);

assert_fails(Rule::test_literal_pod_id, "0xabc"); // Fails (string is too short)
assert_fails(Rule::test_literal_pod_id, "0x"); // Fails (needs at least one pair)
assert_fails(Rule::test_literal_pod_id, &format!("0x{}", "a".repeat(66))); // Fails (string is too long)

// String
assert_parses(Rule::literal_string, "\"hello\"");
assert_parses(Rule::literal_string, "\"escaped \\\" quote\"");
assert_parses(Rule::literal_string, "\"\\\\ backslash\"");
assert_parses(Rule::literal_string, "\"\\uABCD\"");
assert_fails(Rule::literal_string, "\"unterminated");

// PublicKey
assert_parses(Rule::literal_public_key, "PublicKey(base58string)");
assert_fails(Rule::literal_public_key, "PublicKey(OhNo)"); // Fails because O is not valid base58

// Array
assert_parses(Rule::literal_array, "[]");
assert_parses(Rule::literal_array, "[1, \"two\", true]");
assert_parses(Rule::literal_array, "[ [1], #[2] ]");

// Set
assert_parses(Rule::literal_set, "#[]");
assert_parses(Rule::literal_set, "#[1, 2, 3]");
assert_parses(
Rule::literal_set,
"#[ \"a\", 0x0000000000000000000000000000000000000000000000000000000000000000 ]",
);

// Dict
assert_parses(Rule::literal_dict, "{}");
assert_parses(Rule::literal_dict, "{ \"name\": \"Alice\", \"age\": 30 }");
Expand Down
38 changes: 32 additions & 6 deletions src/lang/processor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use plonky2::field::types::Field;

use super::error::ProcessorError;
use crate::{
backends::plonky2::primitives::ec::curve::Point,
frontend::{BuilderArg, CustomPredicateBatchBuilder, StatementTmplBuilder},
lang::parser::Rule,
middleware::{
Expand Down Expand Up @@ -335,10 +336,10 @@ fn validate_and_build_statement_template(
| NativePredicate::Lt
| NativePredicate::LtEq
| NativePredicate::SetContains
| NativePredicate::NotContains
| NativePredicate::DictNotContains
| NativePredicate::SetNotContains => 2,
NativePredicate::NotContains
| NativePredicate::Contains
NativePredicate::Contains
| NativePredicate::ArrayContains
| NativePredicate::DictContains
| NativePredicate::SumOf
Expand Down Expand Up @@ -523,7 +524,6 @@ fn process_and_add_custom_predicate_to_batch(
let public_args_strs: Vec<&str> = public_arg_strings.iter().map(AsRef::as_ref).collect();
let private_args_strs: Vec<&str> = private_arg_strings.iter().map(AsRef::as_ref).collect();
let sts_slice: &[StatementTmplBuilder] = &statement_builders;

if conjunction {
cpb_builder.predicate_and(&name, &public_args_strs, &private_args_strs, sts_slice)?;
} else {
Expand Down Expand Up @@ -667,11 +667,11 @@ fn process_literal_value(
Ok(Value::from(val))
}
Rule::literal_raw => {
let full_literal_str = inner_lit.as_str();
let full_literal_str = inner_lit.clone().into_inner().next().unwrap();
let hex_str_no_prefix = full_literal_str
.as_str()
.strip_prefix("0x")
.unwrap_or(full_literal_str);

.unwrap_or(full_literal_str.as_str());
parse_hex_str_to_raw_value(hex_str_no_prefix)
.map_err(|e| match e {
ProcessorError::InvalidLiteralFormat { kind, value, .. } => {
Expand All @@ -694,6 +694,27 @@ fn process_literal_value(
})
.map(Value::from)
}
Rule::literal_pod_id => {
let hex_str_no_prefix = inner_lit
.as_str()
.strip_prefix("0x")
.unwrap_or(inner_lit.as_str());
let pod_id = parse_hex_str_to_pod_id(hex_str_no_prefix)?;
Ok(Value::from(pod_id))
}
Rule::literal_public_key => {
let pk_str_pair = inner_lit.into_inner().next().unwrap();
let pk_b58 = pk_str_pair.as_str();
let point: Point =
pk_b58
.parse()
.map_err(|e| ProcessorError::InvalidLiteralFormat {
kind: "PublicKey".to_string(),
value: format!("{} (error: {})", pk_b58, e),
span: Some(get_span(&pk_str_pair)),
})?;
Ok(Value::from(point))
}
Rule::literal_string => Ok(Value::from(parse_pest_string_literal(&inner_lit)?)),
Rule::literal_array => {
let elements: Result<Vec<Value>, ProcessorError> = inner_lit
Expand Down Expand Up @@ -823,6 +844,11 @@ fn parse_hex_str_to_raw_value(hex_str: &str) -> Result<middleware::RawValue, Pro
Ok(middleware::RawValue(v))
}

fn parse_hex_str_to_pod_id(hex_str: &str) -> Result<middleware::PodId, ProcessorError> {
let raw = parse_hex_str_to_raw_value(hex_str)?;
Ok(middleware::PodId(raw.into()))
}

// Helper to resolve a wildcard name string to an indexed middleware::Wildcard
// based on an ordered list of names from the current scope (e.g., request or predicate def).
fn resolve_wildcard(
Expand Down
2 changes: 2 additions & 0 deletions src/middleware/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,8 @@ impl fmt::Display for PodId {
write!(f, "self")
} else if self.0 == EMPTY_HASH {
write!(f, "null")
} else if f.alternate() {
write!(f, "{:#}", self.0)
} else {
write!(f, "{}", self.0)
}
Expand Down
Loading