@@ -692,35 +692,36 @@ def _repr_latex_(self) -> str:
692692 lower_bound , upper_bound = pspace (random_symbol ).domain .set .args [:2 ]
693693 integral = sympy .Integral (integral , (random_symbol , lower_bound , upper_bound ))
694694
695- s = latex (integral , mode = 'plain' , order = None )
695+ output = latex (integral , mode = 'plain' , order = None )
696696 else :
697- s = latex (profit_function , mode = 'plain' , order = None )
697+ output = latex (profit_function , mode = 'plain' , order = None )
698698
699- s = s .replace ('F_{0}' , 'F_{0}(T)' ).replace ('F_{1}' , 'F_{1}(T)' )
700- return f'$\\ displaystyle { s } $'
699+ output = output .replace ('F_{0}' , 'F_{0}(T)' ).replace ('F_{1}' , 'F_{1}(T)' )
701700 elif self .kind == 'cost' :
702- y , s , i , N = sympy .symbols ('y s i N' ) # noqa: N806
703- cost_function = (1 / N ) * sympy .Sum (
704- y * (s * self .tp_cost + (1 - s ) * self .fn_cost ) + (1 - y ) * ((1 - s ) * self .tn_cost + s * self .fp_cost ),
705- (i , 0 , N ),
706- )
701+ i , N = sympy .symbols ('i N' ) # noqa: N806
702+ cost_function = (1 / N ) * sympy .Sum (self ._format_cost_function (), (i , 0 , N ))
707703
708704 for symbol in cost_function .free_symbols :
709705 if symbol != N :
710706 cost_function = cost_function .subs (symbol , str (symbol ) + '_i' )
711707
712- s = latex (cost_function , mode = 'plain' , order = None )
713- return f'$\\ displaystyle { s } $'
708+ output = latex (cost_function , mode = 'plain' , order = None )
714709 elif self .kind == 'savings' :
715- y , s , i , N , c0 , c1 = sympy .symbols ('y s i N Cost_{0} Cost_{1}' ) # noqa: N806
716- cost_function = (1 / (N * sympy .Min (c0 , c1 ))) * sympy .Sum (
717- y * (s * self .tp_cost + (1 - s ) * self .fn_cost ) + (1 - y ) * ((1 - s ) * self .tn_cost + s * self .fp_cost ),
718- (i , 0 , N ),
719- )
710+ i , N , c0 , c1 = sympy .symbols ('i N Cost_{0} Cost_{1}' ) # noqa: N806
711+ savings_function = (1 / (N * sympy .Min (c0 , c1 ))) * sympy .Sum (self ._format_cost_function (), (i , 0 , N ))
720712
721- for symbol in cost_function .free_symbols :
713+ for symbol in savings_function .free_symbols :
722714 if symbol not in {N , c0 , c1 }:
723- cost_function = cost_function .subs (symbol , str (symbol ) + '_i' )
715+ savings_function = savings_function .subs (symbol , str (symbol ) + '_i' )
716+
717+ output = latex (savings_function , mode = 'plain' , order = None )
718+ else :
719+ return repr (self )
720+ return f'$\\ displaystyle { output } $'
724721
725- s = latex (cost_function , mode = 'plain' , order = None )
726- return f'$\\ displaystyle { s } $'
722+ def _format_cost_function (self ):
723+ y , s = sympy .symbols ('y s' )
724+ cost_function = y * (s * self .tp_cost + (1 - s ) * self .fn_cost ) + (1 - y ) * (
725+ (1 - s ) * self .tn_cost + s * self .fp_cost
726+ )
727+ return cost_function
0 commit comments