|
1 | | -use proc_macro2::Span; |
2 | | -use proc_macro2::TokenStream as TokenStream2; |
3 | | -use quote::quote; |
4 | | -use quote::ToTokens; |
5 | | -use syn::parse_quote; |
6 | | -use syn::ExprCast; |
7 | | -use syn::{parse::Parse, punctuated::Punctuated, Expr, Ident, Token}; |
| 1 | +use proc_macro2::{Span, TokenStream as TokenStream2}; |
| 2 | +use quote::{format_ident, quote, ToTokens}; |
| 3 | +use syn::{ |
| 4 | + parse::Parse, parse_quote, punctuated::Punctuated, spanned::Spanned, Expr, ExprCast, Ident, |
| 5 | + Token, |
| 6 | +}; |
8 | 7 |
|
9 | 8 | pub struct Factor { |
10 | 9 | residual: Expr, |
@@ -81,10 +80,30 @@ impl Parse for Factor { |
81 | 80 | let m = quote!(factrs::noise); |
82 | 81 | match &input[2] { |
83 | 82 | Expr::Cast(ExprCast { expr, ty, .. }) => { |
84 | | - match ty.to_token_stream().to_string().as_str() { |
85 | | - "std" => Some(parse_quote!(#m::GaussianNoise::from_scalar_sigma(#expr))), |
86 | | - "cov" => Some(parse_quote!(#m::GaussianNoise::from_scalar_cov(#expr))), |
| 83 | + // Make sure it's a cov or std cast |
| 84 | + let ty = match ty.to_token_stream().to_string().as_str() { |
| 85 | + "cov" => Ident::new("cov", ty.span()), |
| 86 | + "std" | "sigma" | "sig" => Ident::new("sigma", ty.span()), |
87 | 87 | _ => return Err(syn::Error::new_spanned(ty, "Unknown cast for noise")), |
| 88 | + }; |
| 89 | + |
| 90 | + // Check if it's a tuple or a single variable |
| 91 | + match expr.as_ref() { |
| 92 | + Expr::Tuple(t) => { |
| 93 | + if t.elems.len() != 2 { |
| 94 | + return Err(syn::Error::new_spanned( |
| 95 | + t, |
| 96 | + "Expected tuple with two elements for split std/cov", |
| 97 | + )); |
| 98 | + } |
| 99 | + let (a, b) = (&t.elems[0], &t.elems[1]); |
| 100 | + let func = format_ident!("from_split_{}", ty); |
| 101 | + Some(parse_quote!(#m::GaussianNoise::#func(#a, #b))) |
| 102 | + } |
| 103 | + _ => { |
| 104 | + let func = format_ident!("from_scalar_{}", ty); |
| 105 | + Some(parse_quote!(#m::GaussianNoise::#func(#expr))) |
| 106 | + } |
88 | 107 | } |
89 | 108 | } |
90 | 109 | Expr::Infer(_) => Some(parse_quote!(#m::UnitNoise)), |
|
0 commit comments