Skip to content

Commit e1b3768

Browse files
Refactor latex repr
1 parent 0dcb41e commit e1b3768

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

empulse/metrics/metric.py

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -692,35 +692,36 @@ def _repr_latex_(self) -> str:
692692
lower_bound, upper_bound = pspace(random_symbol).domain.set.args[:2]
693693
integral = sympy.Integral(integral, (random_symbol, lower_bound, upper_bound))
694694

695-
s = latex(integral, mode='plain', order=None)
695+
output = latex(integral, mode='plain', order=None)
696696
else:
697-
s = latex(profit_function, mode='plain', order=None)
697+
output = latex(profit_function, mode='plain', order=None)
698698

699-
s = s.replace('F_{0}', 'F_{0}(T)').replace('F_{1}', 'F_{1}(T)')
700-
return f'$\\displaystyle {s}$'
699+
output = output.replace('F_{0}', 'F_{0}(T)').replace('F_{1}', 'F_{1}(T)')
701700
elif self.kind == 'cost':
702-
y, s, i, N = sympy.symbols('y s i N') # noqa: N806
703-
cost_function = (1 / N) * sympy.Sum(
704-
y * (s * self.tp_cost + (1 - s) * self.fn_cost) + (1 - y) * ((1 - s) * self.tn_cost + s * self.fp_cost),
705-
(i, 0, N),
706-
)
701+
i, N = sympy.symbols('i N') # noqa: N806
702+
cost_function = (1 / N) * sympy.Sum(self._format_cost_function(), (i, 0, N))
707703

708704
for symbol in cost_function.free_symbols:
709705
if symbol != N:
710706
cost_function = cost_function.subs(symbol, str(symbol) + '_i')
711707

712-
s = latex(cost_function, mode='plain', order=None)
713-
return f'$\\displaystyle {s}$'
708+
output = latex(cost_function, mode='plain', order=None)
714709
elif self.kind == 'savings':
715-
y, s, i, N, c0, c1 = sympy.symbols('y s i N Cost_{0} Cost_{1}') # noqa: N806
716-
cost_function = (1 / (N * sympy.Min(c0, c1))) * sympy.Sum(
717-
y * (s * self.tp_cost + (1 - s) * self.fn_cost) + (1 - y) * ((1 - s) * self.tn_cost + s * self.fp_cost),
718-
(i, 0, N),
719-
)
710+
i, N, c0, c1 = sympy.symbols('i N Cost_{0} Cost_{1}') # noqa: N806
711+
savings_function = (1 / (N * sympy.Min(c0, c1))) * sympy.Sum(self._format_cost_function(), (i, 0, N))
720712

721-
for symbol in cost_function.free_symbols:
713+
for symbol in savings_function.free_symbols:
722714
if symbol not in {N, c0, c1}:
723-
cost_function = cost_function.subs(symbol, str(symbol) + '_i')
715+
savings_function = savings_function.subs(symbol, str(symbol) + '_i')
716+
717+
output = latex(savings_function, mode='plain', order=None)
718+
else:
719+
return repr(self)
720+
return f'$\\displaystyle {output}$'
724721

725-
s = latex(cost_function, mode='plain', order=None)
726-
return f'$\\displaystyle {s}$'
722+
def _format_cost_function(self):
723+
y, s = sympy.symbols('y s')
724+
cost_function = y * (s * self.tp_cost + (1 - s) * self.fn_cost) + (1 - y) * (
725+
(1 - s) * self.tn_cost + s * self.fp_cost
726+
)
727+
return cost_function

0 commit comments

Comments
 (0)