Skip to content

Commit 6846c41

Browse files
committed
Fix potential error caused by floating point errors
1 parent e40168e commit 6846c41

File tree

6 files changed

+63
-74
lines changed

6 files changed

+63
-74
lines changed

extrap/gui/GraphWidget.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from PySide6.QtCore import * # @UnusedWildImport
1717
from PySide6.QtGui import * # @UnusedWildImport
1818
from PySide6.QtWidgets import * # @UnusedWildImport
19-
2019
from extrap.gui.Utils import formatFormula
2120
from extrap.gui.Utils import formatNumber
2221
from extrap.gui.plots.AbstractPlotWidget import AbstractPlotWidget
@@ -569,9 +568,8 @@ def calculate_function(self, function, length_x_axis, x_min=None, x_max=None):
569568
number_of_x_points, x_list, x_values = self._calculate_evaluation_points(length_x_axis, x_min, x_max)
570569
else:
571570
number_of_x_points, x_list, x_values = self._calculate_evaluation_points(length_x_axis)
572-
previous = numpy.seterr(invalid='ignore', divide='ignore')
573-
y_list = function.evaluate(x_list).reshape(-1)
574-
numpy.seterr(**previous)
571+
with numpy.errstate(invalid='ignore', divide='ignore'):
572+
y_list = function.evaluate(x_list).reshape(-1)
575573
cord_list = self._create_drawing_iterator(x_values, y_list)
576574

577575
return cord_list
@@ -585,10 +583,9 @@ def calculate_aggregate_callpath_function(self, functions, length_x_axis):
585583

586584
y_list = numpy.zeros(number_of_x_points)
587585

588-
previous = numpy.seterr(invalid='ignore', divide='ignore')
589-
for function in functions:
590-
y_list += function.evaluate(x_list).reshape(-1)
591-
numpy.seterr(**previous)
586+
with numpy.errstate(invalid='ignore', divide='ignore'):
587+
for function in functions:
588+
y_list += function.evaluate(x_list).reshape(-1)
592589

593590
cord_list = self._create_drawing_iterator(x_values, y_list)
594591

@@ -741,43 +738,41 @@ def calculateMaxY(self, modelList):
741738
y = max(model.predictions)
742739
y_max = max(y, y_max)
743740

744-
previous = numpy.seterr(invalid='ignore', divide='ignore')
745-
746-
if self.combine_all_callpath:
747-
y_agg = 0
741+
with numpy.errstate(invalid='ignore', divide='ignore'):
742+
if self.combine_all_callpath:
743+
y_agg = 0
744+
for model in modelList:
745+
function = model.hypothesis.function
746+
y_agg = y_agg + function.evaluate(pv_list)
747+
y_max = max(y_agg, y_max)
748+
749+
pv_list[param] = 1
750+
y_agg = 0
751+
for model in modelList:
752+
function = model.hypothesis.function
753+
y = function.evaluate(pv_list)
754+
if math.isinf(y):
755+
y = max(model.predictions)
756+
y_agg += y
757+
y_max = max(y_agg, y_max)
758+
759+
# Check the value at the end of the displayed interval
748760
for model in modelList:
749761
function = model.hypothesis.function
750-
y_agg = y_agg + function.evaluate(pv_list)
751-
y_max = max(y_agg, y_max)
762+
y = function.evaluate(pv_list)
763+
if math.isinf(y):
764+
y = max(model.predictions)
765+
y_max = max(y, y_max)
752766

767+
# Check the value at the beginning of the displayed interval
753768
pv_list[param] = 1
754-
y_agg = 0
755769
for model in modelList:
756770
function = model.hypothesis.function
757771
y = function.evaluate(pv_list)
758772
if math.isinf(y):
759773
y = max(model.predictions)
760-
y_agg += y
761-
y_max = max(y_agg, y_max)
762-
763-
# Check the value at the end of the displayed interval
764-
for model in modelList:
765-
function = model.hypothesis.function
766-
y = function.evaluate(pv_list)
767-
if math.isinf(y):
768-
y = max(model.predictions)
769-
y_max = max(y, y_max)
770-
771-
# Check the value at the beginning of the displayed interval
772-
pv_list[param] = 1
773-
for model in modelList:
774-
function = model.hypothesis.function
775-
y = function.evaluate(pv_list)
776-
if math.isinf(y):
777-
y = max(model.predictions)
778-
y_max = max(y, y_max)
774+
y_max = max(y, y_max)
779775

780-
numpy.seterr(**previous)
781776
# Ensure that the maximum value is never too small
782777
if y_max < 0.000001:
783778
y_max = 1

extrap/gui/SelectorWidget.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@
1111
from typing import Optional, Sequence, TYPE_CHECKING, Tuple
1212

1313
import numpy
14-
from PySide6.QtCore import Slot
1514
from PySide6.QtWidgets import * # @UnusedWildImport
16-
1715
from extrap.entities.calltree import Node
1816
from extrap.entities.metric import Metric
1917
from extrap.entities.model import Model
@@ -346,9 +344,9 @@ def update_min_max_value(self):
346344
param_value_list = self.getParameterValues()
347345
call_tree = experiment.call_tree
348346
nodes = call_tree.get_nodes()
349-
previous = numpy.seterr(divide='ignore', invalid='ignore')
350-
value_list = self.iterate_children(model_set.models, param_value_list, nodes, selected_metric)
351-
numpy.seterr(**previous)
347+
with numpy.errstate(divide='ignore', invalid='ignore'):
348+
value_list = self.iterate_children(model_set.models, param_value_list, nodes, selected_metric)
349+
352350
if len(value_list) > 0:
353351
min_max_value = max(0.0, min(value_list)), max(0.0, max(value_list))
354352
self.min_value, self.max_value = min_max_value

extrap/gui/TreeModel.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from __future__ import annotations
99

1010
import copy
11-
import numpy
12-
from PySide6.QtCore import * # @UnusedWildImport
1311
from enum import Enum, auto
1412
from typing import Optional, TYPE_CHECKING, List, Callable, Any
1513

14+
import numpy
15+
from PySide6.QtCore import * # @UnusedWildImport
1616
from extrap.entities import calltree
1717
from extrap.entities.calltree import CallTree, Node
1818
from extrap.entities.experiment import Experiment
@@ -148,12 +148,13 @@ def data(self, index, role=None):
148148
return formatFormula(formula.to_string(*parameters))
149149
else:
150150
parameters = self.selector_widget.getParameterValues()
151-
previous = numpy.seterr(divide='ignore', invalid='ignore')
152-
if role == Qt.ToolTipRole and isinstance(model, SegmentedModel):
153-
res = _format_number_segmented_model(model, lambda m: m.hypothesis.function.evaluate(parameters))
154-
else:
155-
res = formatNumber(str(formula.evaluate(parameters)))
156-
numpy.seterr(**previous)
151+
with numpy.errstate(divide='ignore', invalid='ignore'):
152+
if role == Qt.ToolTipRole and isinstance(model, SegmentedModel):
153+
res = _format_number_segmented_model(model,
154+
lambda m: m.hypothesis.function.evaluate(parameters))
155+
else:
156+
res = formatNumber(str(formula.evaluate(parameters)))
157+
157158
return res
158159
elif index.column() == 4:
159160
if role == Qt.ToolTipRole and isinstance(model, SegmentedModel):
@@ -180,9 +181,8 @@ def data(self, index, role=None):
180181
def get_comparison_value(self, model):
181182
parameters = self.selector_widget.getParameterValues()
182183
formula = model.hypothesis.function
183-
previous = numpy.seterr(divide='ignore', invalid='ignore')
184-
value = formula.evaluate(parameters)
185-
numpy.seterr(**previous)
184+
with numpy.errstate(divide='ignore', invalid='ignore'):
185+
value = formula.evaluate(parameters)
186186
return value
187187

188188
def getSelectedModel(self, callpath) -> tuple[Optional[Model], Experiment]:

extrap/gui/plots/BaseGraphWidget.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,12 @@
1313
import matplotlib
1414
import numpy as np
1515
from PySide6.QtWidgets import QSizePolicy
16+
from extrap.util.formatting_helper import replace_method_parameters
1617
from matplotlib import patches as mpatches
1718
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
1819
from matplotlib.figure import Figure
1920
from mpl_toolkits.mplot3d import Axes3D
2021

21-
from extrap.util.formatting_helper import replace_method_parameters
22-
2322
if TYPE_CHECKING:
2423
from extrap.gui.MainWidget import MainWidget
2524

@@ -105,15 +104,14 @@ def calculate_z_models(self, maxX, maxY, model_list, max_z=0):
105104
# Get the z value for the x and y value
106105
z_List = list()
107106
Z_List = list()
108-
previous = np.seterr(invalid='ignore', divide='ignore')
109-
for model in model_list:
110-
function = model.hypothesis.function
111-
zs = self.calculate_z_optimized(X, Y, function)
112-
Z = zs.reshape(X.shape)
113-
z_List.append(zs)
114-
Z_List.append(Z)
115-
max_z = max(max_z, np.max(zs[np.logical_not(np.isinf(zs))]))
116-
np.seterr(**previous)
107+
with np.errstate(invalid='ignore', divide='ignore'):
108+
for model in model_list:
109+
function = model.hypothesis.function
110+
zs = self.calculate_z_optimized(X, Y, function)
111+
Z = zs.reshape(X.shape)
112+
z_List.append(zs)
113+
Z_List.append(Z)
114+
max_z = max(max_z, np.max(zs[np.logical_not(np.isinf(zs))]))
117115
for z, Z in zip(z_List, Z_List):
118116
z[np.isinf(z)] = max_z
119117
Z[np.isinf(Z)] = max_z

extrap/modelers/multi_parameter/multi_parameter_modeler.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from typing import Sequence
1212

1313
import numpy as np
14-
1514
from extrap.entities.coordinate import Coordinate
1615
from extrap.entities.functions import ConstantFunction
1716
from extrap.entities.functions import MultiParameterFunction
@@ -142,10 +141,9 @@ def make_measurement(c, ms: Sequence[Measurement]):
142141
warnings.warn(f"Could not use all measurement points. At least {self.min_measurement_points ** 2} "
143142
f"measurements are needed; one for each combination of parameters.")
144143

145-
previous = np.seterr(invalid='ignore')
146-
combined_measurements = [[make_measurement(c, ms) for c, ms in grp.items() if ms]
147-
for p, grp in enumerate(result_groups)]
148-
np.seterr(**previous)
144+
with np.errstate(invalid='ignore'):
145+
combined_measurements = [[make_measurement(c, ms) for c, ms in grp.items() if ms]
146+
for p, grp in enumerate(result_groups)]
149147

150148
return combined_measurements
151149

extrap/modelers/single_parameter/abstract_base.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,13 @@ def compare_hypotheses(self, old: Hypothesis, new: SingleParameterHypothesis, me
5050
# get the compound terms of the new hypothesis
5151
compound_terms = new.function.compound_terms
5252

53-
previous = numpy.seterr(divide='ignore', invalid='ignore')
54-
# for all compound terms check if they are smaller than minimum allowed contribution
55-
for term in compound_terms:
56-
# ignore this hypothesis, since one of the terms contributes less than epsilon to the function
57-
if term.coefficient == 0 or new.calc_term_contribution(term, measurements) < self.epsilon:
58-
return False
59-
numpy.seterr(**previous)
53+
with numpy.errstate(divide='ignore', invalid='ignore'):
54+
# for all compound terms check if they are smaller than minimum allowed contribution
55+
for term in compound_terms:
56+
# ignore this hypothesis, since one of the terms contributes less than epsilon to the function
57+
if (term.coefficient == 0
58+
or new.calc_term_contribution(term, measurements) < self.minimum_term_contribution):
59+
return False
6060

6161
# print smapes in debug mode
6262
logging.debug("next hypothesis SMAPE: %g RSS: %g", new.SMAPE, new.RSS)

0 commit comments

Comments
 (0)