@@ -24,18 +24,22 @@ pub struct SanitizedPriceEstimator {
2424 inner : Arc < dyn PriceEstimating > ,
2525 bad_token_detector : Arc < dyn BadTokenDetecting > ,
2626 native_token : Address ,
27+ /// Enables the short-circuiting logic in case of sell and buy is the same
28+ is_estimating_native_price : bool ,
2729}
2830
2931impl SanitizedPriceEstimator {
3032 pub fn new (
3133 inner : Arc < dyn PriceEstimating > ,
3234 native_token : Address ,
3335 bad_token_detector : Arc < dyn BadTokenDetecting > ,
36+ is_estimating_native_price : bool ,
3437 ) -> Self {
3538 Self {
3639 inner,
3740 native_token,
3841 bad_token_detector,
42+ is_estimating_native_price,
3943 }
4044 }
4145
@@ -64,7 +68,10 @@ impl PriceEstimating for SanitizedPriceEstimator {
6468 self . handle_bad_tokens ( & query) . await ?;
6569
6670 // buy_token == sell_token => 1 to 1 conversion
67- if query. buy_token == query. sell_token {
71+ // Only in case of native price estimation.
72+ // For regular price estimation, the sell and buy tokens can
73+ // be the same and should be priced as usual
74+ if self . is_estimating_native_price && query. buy_token == query. sell_token {
6875 let estimation = Estimate {
6976 out_amount : query. in_amount . get ( ) . into_alloy ( ) ,
7077 gas : 0 ,
@@ -454,11 +461,12 @@ mod tests {
454461 }
455462 . boxed ( )
456463 } ) ;
457-
464+ let bad_token_detector = Arc :: new ( bad_token_detector ) ;
458465 let sanitized_estimator = SanitizedPriceEstimator {
459466 inner : Arc :: new ( wrapped_estimator) ,
460- bad_token_detector : Arc :: new ( bad_token_detector) ,
467+ bad_token_detector : bad_token_detector. clone ( ) ,
461468 native_token,
469+ is_estimating_native_price : true ,
462470 } ;
463471
464472 for ( query, expectation) in queries {
@@ -473,5 +481,110 @@ mod tests {
473481 }
474482 }
475483 }
484+
485+ let queries = [
486+ // Can be estimated by `sanitized_estimator` because `buy_token` and `sell_token` are
487+ // identical.
488+ (
489+ Query {
490+ verification : Default :: default ( ) ,
491+ sell_token : Address :: with_last_byte ( 1 ) ,
492+ buy_token : Address :: with_last_byte ( 1 ) ,
493+ in_amount : NonZeroU256 :: try_from ( 1 ) . unwrap ( ) ,
494+ kind : OrderKind :: Sell ,
495+ block_dependent : false ,
496+ timeout : HEALTHY_PRICE_ESTIMATION_TIME ,
497+ } ,
498+ Ok ( Estimate {
499+ out_amount : AlloyU256 :: ONE ,
500+ gas : 100 ,
501+ solver : Default :: default ( ) ,
502+ verified : true ,
503+ execution : Default :: default ( ) ,
504+ } ) ,
505+ ) ,
506+ (
507+ Query {
508+ verification : Default :: default ( ) ,
509+ sell_token : native_token,
510+ buy_token : native_token,
511+ in_amount : NonZeroU256 :: try_from ( 1 ) . unwrap ( ) ,
512+ kind : OrderKind :: Sell ,
513+ block_dependent : false ,
514+ timeout : HEALTHY_PRICE_ESTIMATION_TIME ,
515+ } ,
516+ Ok ( Estimate {
517+ out_amount : AlloyU256 :: ONE ,
518+ gas : 100 ,
519+ solver : Default :: default ( ) ,
520+ verified : true ,
521+ execution : Default :: default ( ) ,
522+ } ) ,
523+ ) ,
524+ ] ;
525+
526+ // SanitizedPriceEstimator will simply forward the Query in the sell=buy case
527+ // if it is not calculating native price
528+ let first_forwarded_query = queries[ 0 ] . 0 . clone ( ) ;
529+
530+ // SanitizedPriceEstimator will simply forward the Query if sell=buy of native
531+ // token case if it is not calculating the native price
532+ let second_forwarded_query = queries[ 1 ] . 0 . clone ( ) ;
533+
534+ let mut wrapped_estimator = MockPriceEstimating :: new ( ) ;
535+ wrapped_estimator
536+ . expect_estimate ( )
537+ . times ( 1 )
538+ . withf ( move |query| * * query == first_forwarded_query)
539+ . returning ( |_| {
540+ async {
541+ Ok ( Estimate {
542+ out_amount : AlloyU256 :: ONE ,
543+ gas : 100 ,
544+ solver : Default :: default ( ) ,
545+ verified : true ,
546+ execution : Default :: default ( ) ,
547+ } )
548+ }
549+ . boxed ( )
550+ } ) ;
551+ wrapped_estimator
552+ . expect_estimate ( )
553+ . times ( 1 )
554+ . withf ( move |query| * * query == second_forwarded_query)
555+ . returning ( |_| {
556+ async {
557+ Ok ( Estimate {
558+ out_amount : AlloyU256 :: ONE ,
559+ gas : 100 ,
560+ solver : Default :: default ( ) ,
561+ verified : true ,
562+ execution : Default :: default ( ) ,
563+ } )
564+ }
565+ . boxed ( )
566+ } ) ;
567+
568+ let sanitized_estimator_non_native = SanitizedPriceEstimator {
569+ inner : Arc :: new ( wrapped_estimator) ,
570+ bad_token_detector,
571+ native_token,
572+ is_estimating_native_price : false ,
573+ } ;
574+
575+ for ( query, expectation) in queries {
576+ let result = sanitized_estimator_non_native
577+ . estimate ( Arc :: new ( query) )
578+ . await ;
579+ match result {
580+ Ok ( estimate) => assert_eq ! ( estimate, expectation. unwrap( ) ) ,
581+ Err ( err) => {
582+ // we only compare the error variant; everything else would be a PITA
583+ let reported_error = std:: mem:: discriminant ( & err) ;
584+ let expected_error = std:: mem:: discriminant ( & expectation. unwrap_err ( ) ) ;
585+ assert_eq ! ( reported_error, expected_error) ;
586+ }
587+ }
588+ }
476589 }
477590}
0 commit comments