diff --git a/tools/hermes/src/main.rs b/tools/hermes/src/main.rs index cebb846de5..a50f8c1a6d 100644 --- a/tools/hermes/src/main.rs +++ b/tools/hermes/src/main.rs @@ -1,8 +1,9 @@ mod errors; mod parse; +mod transform; mod ui_test_shim; -use std::{env, fs, path::PathBuf, process::exit}; +use std::{env, path::PathBuf, process::exit}; fn main() { if env::var("HERMES_UI_TEST_MODE").is_ok() { @@ -17,23 +18,27 @@ fn main() { } let file_path = PathBuf::from(&args[1]); - let source = match fs::read_to_string(&file_path) { - Ok(s) => s, - Err(e) => { - eprintln!("Error reading file: {}", e); - exit(1); - } - }; let mut has_errors = false; - parse::visit_hermes_items_in_file(&file_path, &source, |res| { + let mut edits = Vec::new(); + let res = parse::read_file_and_visit_hermes_items(&file_path, |_src, res| { if let Err(e) = res { has_errors = true; eprint!("{:?}", miette::Report::new(e)); + } else if let Ok(item) = res { + transform::append_edits(&item, &mut edits); } }); + let source = res.unwrap_or_else(|e| { + eprintln!("Error parsing file: {}", e); + exit(1); + }); + if has_errors { exit(1); } + + let mut source = source.into_bytes(); + transform::apply_edits(&mut source, &edits); } diff --git a/tools/hermes/src/parse.rs b/tools/hermes/src/parse.rs index 44cf450dfd..2fbf2616a3 100644 --- a/tools/hermes/src/parse.rs +++ b/tools/hermes/src/parse.rs @@ -1,10 +1,15 @@ -use std::path::{Path, PathBuf}; +use std::{ + fs, io, + ops::Range, + path::{Path, PathBuf}, +}; use log::{debug, trace}; use miette::{NamedSource, SourceSpan}; +use proc_macro2::Span; use syn::{ - visit::Visit, Attribute, Error, Expr, ItemEnum, ItemFn, ItemImpl, ItemMod, ItemStruct, - ItemTrait, ItemUnion, Lit, Meta, + spanned::Spanned as _, visit::Visit, Attribute, Error, Expr, ItemEnum, ItemFn, ItemImpl, + ItemMod, ItemStruct, ItemTrait, ItemUnion, Lit, Meta, }; use crate::errors::HermesError; @@ -30,7 +35,7 @@ impl std::error::Error for ParseError {} /// The item from the original source code. #[derive(Clone, Debug)] -enum ParsedItem { +pub enum ParsedItem { Fn(ItemFn), Struct(ItemStruct), Enum(ItemEnum), @@ -56,7 +61,7 @@ impl ParsedItem { /// A complete parsed item including its module path and the extracted Lean block. #[derive(Debug)] pub struct ParsedLeanItem { - item: ParsedItem, + pub item: ParsedItem, module_path: Vec, lean_block: String, source_file: Option, @@ -67,9 +72,9 @@ pub struct ParsedLeanItem { /// /// If parsing fails, or if any item has multiple Lean blocks, the callback is /// invoked with an `Err`. -fn visit_hermes_items(source: &str, f: F) +pub fn visit_hermes_items(source: &str, f: F) where - F: FnMut(Result), + F: FnMut(&str, Result), { visit_hermes_items_internal(source, None, f) } @@ -77,16 +82,18 @@ where /// Parses the given Rust source code from a file path and invokes the callback `f` /// for each item annotated with a `/// ```lean` block. Parsing errors and generated /// items will be associated with this file path. -pub fn visit_hermes_items_in_file(path: &Path, source: &str, f: F) +pub fn read_file_and_visit_hermes_items(path: &Path, f: F) -> Result where - F: FnMut(Result), + F: FnMut(&str, Result), { - visit_hermes_items_internal(source, Some(path.to_path_buf()), f) + let source = fs::read_to_string(path).expect("Failed to read file"); + visit_hermes_items_internal(&source, Some(path.to_path_buf()), f); + Ok(source) } fn visit_hermes_items_internal(source: &str, source_file: Option, mut f: F) where - F: FnMut(Result), + F: FnMut(&str, Result), { trace!("Parsing source code into syn::File"); let file_name = { @@ -109,11 +116,14 @@ where } Err(e) => { debug!("Failed to parse source code: {}", e); - f(Err(HermesError::SynError { - src: named_source.clone(), - span: span_to_miette(e.span()), - msg: e.to_string(), - })); + f( + source, + Err(HermesError::SynError { + src: named_source.clone(), + span: span_to_miette(e.span()), + msg: e.to_string(), + }), + ); return; } }; @@ -141,31 +151,40 @@ struct HermesVisitor { impl HermesVisitor where - F: FnMut(Result), + F: FnMut(&str, Result), { - fn check_and_add(&mut self, item: ParsedItem) { + fn check_and_add(&mut self, item: ParsedItem, span: Span) { + let Range { start, end } = span.byte_range(); + let source = &self.source_code.as_str()[start..end]; + let attrs = item.attrs(); trace!("Checking item in module path `{:?}` for ```lean block", self.current_path); match extract_lean_block(attrs) { Ok(Some(lean_block)) => { debug!("Found valid ```lean block for item in `{:?}`", self.current_path); - (self.callback)(Ok(ParsedLeanItem { - item, - module_path: self.current_path.clone(), - lean_block, - source_file: self.source_file.clone(), - })); + (self.callback)( + source, + Ok(ParsedLeanItem { + item, + module_path: self.current_path.clone(), + lean_block, + source_file: self.source_file.clone(), + }), + ); } Ok(None) => { trace!("No ```lean block found for item"); } // Skip item Err(e) => { debug!("Error extracting ```lean block: {}", e); - (self.callback)(Err(HermesError::DocBlockError { - src: self.named_source.clone(), - span: span_to_miette(e.span()), - msg: e.to_string(), - })); + (self.callback)( + source, + Err(HermesError::DocBlockError { + src: self.named_source.clone(), + span: span_to_miette(e.span()), + msg: e.to_string(), + }), + ); } } } @@ -173,7 +192,7 @@ where impl<'ast, F> Visit<'ast> for HermesVisitor where - F: FnMut(Result), + F: FnMut(&str, Result), { fn visit_item_mod(&mut self, node: &'ast ItemMod) { let mod_name = node.ident.to_string(); @@ -186,37 +205,37 @@ where fn visit_item_fn(&mut self, node: &'ast ItemFn) { trace!("Visiting Fn {}", node.sig.ident); - self.check_and_add(ParsedItem::Fn(node.clone())); + self.check_and_add(ParsedItem::Fn(node.clone()), node.span()); syn::visit::visit_item_fn(self, node); } fn visit_item_struct(&mut self, node: &'ast ItemStruct) { trace!("Visiting Struct {}", node.ident); - self.check_and_add(ParsedItem::Struct(node.clone())); + self.check_and_add(ParsedItem::Struct(node.clone()), node.span()); syn::visit::visit_item_struct(self, node); } fn visit_item_enum(&mut self, node: &'ast ItemEnum) { trace!("Visiting Enum {}", node.ident); - self.check_and_add(ParsedItem::Enum(node.clone())); + self.check_and_add(ParsedItem::Enum(node.clone()), node.span()); syn::visit::visit_item_enum(self, node); } fn visit_item_union(&mut self, node: &'ast ItemUnion) { trace!("Visiting Union {}", node.ident); - self.check_and_add(ParsedItem::Union(node.clone())); + self.check_and_add(ParsedItem::Union(node.clone()), node.span()); syn::visit::visit_item_union(self, node); } fn visit_item_trait(&mut self, node: &'ast ItemTrait) { trace!("Visiting Trait {}", node.ident); - self.check_and_add(ParsedItem::Trait(node.clone())); + self.check_and_add(ParsedItem::Trait(node.clone()), node.span()); syn::visit::visit_item_trait(self, node); } fn visit_item_impl(&mut self, node: &'ast ItemImpl) { trace!("Visiting Impl"); - self.check_and_add(ParsedItem::Impl(node.clone())); + self.check_and_add(ParsedItem::Impl(node.clone()), node.span()); syn::visit::visit_item_impl(self, node); } } @@ -319,9 +338,9 @@ mod tests { use super::*; - fn parse_to_vec(code: &str) -> Vec> { + fn parse_to_vec(code: &str) -> Vec<(String, Result)> { let mut items = Vec::new(); - visit_hermes_items(code, |res| items.push(res)); + visit_hermes_items(code, |src, res| items.push((src.to_string(), res))); items } @@ -334,7 +353,15 @@ mod tests { fn foo() {} "#; let items = parse_to_vec(code); - let item = items.into_iter().next().unwrap().unwrap(); + let (src, res) = items.into_iter().next().unwrap(); + let item = res.unwrap(); + assert_eq!( + src, + "/// ```lean + /// theorem foo : True := by trivial + /// ``` + fn foo() {}" + ); assert!(matches!(item.item, ParsedItem::Fn(_))); assert_eq!(item.lean_block, " ```lean\n theorem foo : True := by trivial\n ```\n"); assert!(item.source_file.is_none()); @@ -352,7 +379,18 @@ mod tests { fn foo() {} "#; let items = parse_to_vec(code); - let err = items.into_iter().next().unwrap().unwrap_err(); + let (src, res) = items.into_iter().next().unwrap(); + let err = res.unwrap_err(); + assert_eq!( + src, + "/// ```lean + /// a + /// ``` + /// ```lean + /// b + /// ``` + fn foo() {}" + ); assert!(err.to_string().contains("Multiple lean blocks")); } @@ -364,7 +402,14 @@ mod tests { fn foo() {} "#; let items = parse_to_vec(code); - let err = items.into_iter().next().unwrap().unwrap_err(); + let (src, res) = items.into_iter().next().unwrap(); + let err = res.unwrap_err(); + assert_eq!( + src, + "/// ```lean + /// a + fn foo() {}" + ); assert!(err.to_string().contains("Unclosed")); } @@ -380,7 +425,14 @@ mod tests { } "#; let items = parse_to_vec(code); - let item = items.into_iter().next().unwrap().unwrap(); + let (src, res) = items.into_iter().next().unwrap(); + let item = res.unwrap(); + assert_eq!( + src, + "/// ```lean + /// ``` + fn foo() {}" + ); assert_eq!(item.module_path, vec!["a", "b"]); } @@ -392,9 +444,20 @@ mod tests { fn foo() {} "#; let mut items = Vec::new(); - visit_hermes_items_in_file(Path::new("src/foo.rs"), code, |res| items.push(res)); - let err = items.into_iter().next().unwrap().unwrap_err(); - let rep = format!("{:?}", miette::Report::new(err)); + visit_hermes_items_internal( + code, + Some(Path::new("src/foo.rs").to_path_buf()), + |source: &str, res| items.push((source.to_string(), res)), + ); + let (src, res) = items.into_iter().next().unwrap(); + assert_eq!( + src, + "/// ```lean + /// a + fn foo() {}" + ); + + let rep = format!("{:?}", miette::Report::new(res.unwrap_err())); assert!(rep.contains("src/foo.rs")); assert!(rep.contains("Unclosed")); } @@ -417,7 +480,7 @@ mod c { } "; let mut items = Vec::new(); - visit_hermes_items(source, |res| items.push(res)); + visit_hermes_items(source, |_src, res| items.push(res)); let i1 = items[0].as_ref().unwrap(); let i2 = items[1].as_ref().unwrap(); @@ -464,8 +527,8 @@ mod c { let path = std::path::Path::new("src/foo.rs"); let mut items = Vec::new(); - visit_hermes_items_in_file(path, code1, |res| items.push(res)); - visit_hermes_items_in_file(path, code2, |res| items.push(res)); + visit_hermes_items_internal(code1, Some(path.to_path_buf()), |_src, res| items.push(res)); + visit_hermes_items_internal(code2, Some(path.to_path_buf()), |_src, res| items.push(res)); let mut report_string = String::new(); for err in items.into_iter().filter_map(|r| r.err()) { diff --git a/tools/hermes/src/transform.rs b/tools/hermes/src/transform.rs new file mode 100644 index 0000000000..06e2e5abac --- /dev/null +++ b/tools/hermes/src/transform.rs @@ -0,0 +1,166 @@ +use proc_macro2::Span; +use syn::spanned::Spanned; + +use crate::parse::{ParsedItem, ParsedLeanItem}; + +/// Appends the spans of text that should be blanked out in the shadow crate. +/// +/// For `unsafe` functions with Hermes annotations, this targets: +/// 1. The `unsafe` keyword (to make the function signature "safe" for Aeneas). +/// 2. The entire function block (to remove the unverified implementation). +pub fn append_edits(item: &ParsedLeanItem, edits: &mut Vec) { + if let ParsedItem::Fn(func) = &item.item { + if let Some(unsafety) = &func.sig.unsafety { + // 1. Mark the `unsafe` keyword for blanking. + // Result: `unsafe fn` -> ` fn` + edits.push(unsafety.span()); + + // TODO: + // - Only blank bodies for functions which are modeled. + // - Figure out what to replace these bodies with. + edits.push(func.block.span()); + } + } +} + +/// Applies a set of redaction edits to the source buffer in-place. +/// +/// For each span in `edits`, this function replaces all characters with spaces +/// (`0x20`), except for newline characters (`0x0A` and `0x0D`), which are +/// preserved to maintain line numbering and Windows compatibility. This allows +/// the shadow crate to report errors on spans that align with the user's +/// original file. +/// +/// # Panics +/// +/// Panics if any span in `edits` is not in-bounds of `buffer`. +pub fn apply_edits(buffer: &mut [u8], edits: &[Span]) { + for span in edits { + for byte in &mut buffer[span.byte_range()] { + if !matches!(*byte, b'\n' | b'\r') { + *byte = b' '; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_apply_edits_preserves_newlines() { + let source = b"unsafe fn test() {\r\n let a = 1;\n let b = 2;\r\n}"; + let mut buffer = source.to_vec(); + + let file = syn::parse_file(std::str::from_utf8(source).unwrap()).unwrap(); + let func = match &file.items[0] { + syn::Item::Fn(f) => f, + _ => panic!("Expected function"), + }; + + let edits = vec![func.sig.unsafety.unwrap().span(), func.block.span()]; + + apply_edits(&mut buffer, &edits); + + let expected = b" fn test() \r\n \n \r\n ".to_vec(); + assert_eq!(std::str::from_utf8(&buffer).unwrap(), std::str::from_utf8(&expected).unwrap()); + } + + #[test] + fn test_apply_edits_with_parsed_item() { + let source = " + /// ```lean + /// theorem foo : True := by trivial + /// ``` + unsafe fn foo() { + let x = 1; + } + "; + let mut items = Vec::new(); + crate::parse::visit_hermes_items(source, |_src, res| items.push(res)); + + let item = items.into_iter().next().unwrap().unwrap(); + let mut edits = Vec::new(); + append_edits(&item, &mut edits); + + let mut buffer = source.as_bytes().to_vec(); + apply_edits(&mut buffer, &edits); + + let expected = " + /// ```lean + /// theorem foo : True := by trivial + /// ``` + fn foo() + + + "; + assert_eq!(std::str::from_utf8(&buffer).unwrap(), expected); + } + + #[test] + fn test_apply_edits_multibyte_utf8() { + // Source contains multi-byte characters: + // - Immediately preceding `unsafe`: `/*前*/` + // - Immediately following `{`: `/*后*/` + // - Immediately before `}`: `/*前*/` + let source = " + fn safe() {} + /// ```lean + /// ``` + /*前*/unsafe fn foo() {/*后*/ + let x = '中'; + /*前*/} + "; + let mut items = Vec::new(); + crate::parse::scan_compilation_unit(source, |_src, res| items.push(res)); + + // Find the unsafe function (should be the first item, as safe() is skipped) + let item = items.into_iter().find(|i| { + if let Ok(ParsedLeanItem { item: ParsedItem::Fn(f), .. }) = i { + f.sig.ident == "foo" + } else { + false + } + }).unwrap().unwrap(); + + let mut edits = Vec::new(); + append_edits(&item, &mut edits); + + let mut buffer = source.as_bytes().to_vec(); + apply_edits(&mut buffer, &edits); + + // Expected whitespace lengths: + // Line 1: ` /*前*/unsafe fn foo() {/*后*/` + // - Indent: 12 spaces + // - `/*前*/`: 7 bytes (preserved) + // - `unsafe`: 6 bytes -> 6 spaces + // - ` fn foo() {`: 9 bytes (preserved) + // - `/*后*/`: 7 bytes -> 7 spaces + // + // Line 2: ` let x = '中';` + // - Indent: 16 spaces -> 16 spaces + // - `let x = '中';`: 14 bytes -> 14 spaces + // `let` (3) + ` ` (1) + `x` (1) + ` ` (1) + `=` (1) + ` ` (1) + `'` (1) + `中` (3) + `'` (1) + `;` (1) = 14 + // Total: 30 spaces + // + // Line 3: ` /*前*/}` + // - Indent 12 spaces -> 12 spaces + // - `/*前*/`: 7 bytes -> 7 spaces + // - `}`: 1 byte (preserved) + // Total: 19 spaces + `}` + + let line_with_unsafe = " /*前*/ fn foo() { "; + let line_body = " ".repeat(30); + let line_end = format!("{}}}", " ".repeat(19)); + + let expected = format!( + "\n fn safe() {{}}\n /// ```lean\n /// ```\n{}\n{}\n{}\n ", + line_with_unsafe, + line_body, + line_end + ); + + assert_eq!(std::str::from_utf8(&buffer).unwrap(), expected); + } +} diff --git a/tools/hermes/src/ui_test_shim.rs b/tools/hermes/src/ui_test_shim.rs index 94b45986fe..181821fabc 100644 --- a/tools/hermes/src/ui_test_shim.rs +++ b/tools/hermes/src/ui_test_shim.rs @@ -1,4 +1,4 @@ -use std::{env, fs, path::PathBuf, process::exit}; +use std::{env, path::PathBuf, process::exit}; use miette::Diagnostic as _; use serde::Serialize; @@ -34,15 +34,15 @@ pub fn run() { }); // Run logic with JSON emitter - let source = fs::read_to_string(&file_path).unwrap_or_default(); let mut has_errors = false; - parse::visit_hermes_items_in_file(&file_path, &source, |res| { + parse::read_file_and_visit_hermes_items(&file_path, |source, res| { if let Err(e) = res { has_errors = true; emit_rustc_json(&e, &source, file_path.to_str().unwrap()); } - }); + }) + .unwrap(); if has_errors { exit(1);