Skip to content

Commit 093ce46

Browse files
dsavoiunsmith-
andauthored
Add flow: wrap option to handle circular axes (#276)
* New flow behavior `wrap`. * Add test case for `wrap` flow behavior. * Factor out flow behavior parsing to helper function. * Correct arg type --------- Co-authored-by: Nick Smith <[email protected]>
1 parent 6692fb2 commit 093ce46

File tree

4 files changed

+56
-23
lines changed

4 files changed

+56
-23
lines changed

include/correction.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ class HashPRNG {
175175
};
176176

177177
// common internal for Binning and MultiBinning
178-
enum class _FlowBehavior {value, clamp, error};
178+
enum class _FlowBehavior {value, clamp, error, wrap};
179179

180180
using _NonUniformBins = std::vector<double>;
181181

src/correction.cc

+42-20
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,22 @@
2525

2626
using namespace correction;
2727

28+
//! helper function for parsing flow behavior from string
29+
_FlowBehavior parse_flowbehavior(const rapidjson::Value& flowbehavior) {
30+
if ( flowbehavior == "clamp" ) {
31+
return _FlowBehavior::clamp;
32+
}
33+
else if ( flowbehavior == "error" ) {
34+
return _FlowBehavior::error;
35+
}
36+
else if ( flowbehavior == "wrap" ) {
37+
return _FlowBehavior::wrap;
38+
}
39+
else {
40+
return _FlowBehavior::value;
41+
}
42+
}
43+
2844
class correction::JSONObject {
2945
public:
3046
JSONObject(rapidjson::Value::ConstObject&& json) : json_(json) { }
@@ -149,20 +165,33 @@ namespace {
149165
return bins->n; // the default value is stored at the end of the content array, after the last bin
150166
case _FlowBehavior::clamp:
151167
return value < bins->low ? 0 : bins->n - 1; // assuming we always have at least 1 bin
168+
case _FlowBehavior::wrap:
169+
break;
152170
case _FlowBehavior::error:
153171
const std::string belowOrAbove = value < bins->low ? "below" : "above";
154172
auto msg = "Index " + belowOrAbove + " bounds in " + name + " for input argument " + std::to_string(variableIdx) + " value: " + std::to_string(value);
155173
throw std::runtime_error(std::move(msg));
156174
}
157175
}
158176

159-
std::size_t binIdx = bins->n * ((value - bins->low) / (bins->high - bins->low));
177+
double norm_value = ((value - bins->low) / (bins->high - bins->low));
178+
if (flow == _FlowBehavior::wrap) {
179+
norm_value -= std::floor(norm_value);
180+
}
181+
std::size_t binIdx = bins->n * norm_value;
160182
return binIdx;
161183
}
162184

163185
// otherwise we have non-uniform binning
164186
using namespace std::string_literals;
165187
const auto bins = std::get<_NonUniformBins>(bins_);
188+
if ( flow == _FlowBehavior::wrap ) {
189+
double low = bins[0];
190+
double high = bins[bins.size() - 1];
191+
double norm_value = (value - low) / (high - low);
192+
norm_value -= std::floor(norm_value);
193+
value = low + norm_value * (high - low);
194+
}
166195

167196
auto it = std::upper_bound(std::begin(bins), std::end(bins), value);
168197
if ( it == std::begin(bins) ) { // underflow
@@ -172,6 +201,9 @@ namespace {
172201
else if ( flow == _FlowBehavior::error ) {
173202
throw std::runtime_error("Index below bounds in "s + name + " for input argument " + std::to_string(variableIdx) + " value: " + std::to_string(value));
174203
}
204+
else if ( flow == _FlowBehavior::wrap ) {
205+
throw std::logic_error("I should not have ever seen an underflow");
206+
}
175207
else { // clamp
176208
it++;
177209
}
@@ -183,6 +215,9 @@ namespace {
183215
else if ( flow == _FlowBehavior::error ) {
184216
throw std::runtime_error("Index above bounds in "s + name + " for input argument " + std::to_string(variableIdx) + " value: " + std::to_string(value));
185217
}
218+
else if ( flow == _FlowBehavior::wrap ) {
219+
throw std::logic_error("I should not have ever seen an overflow");
220+
}
186221
else { // clamp
187222
it--;
188223
}
@@ -461,15 +496,9 @@ Binning::Binning(const JSONObject& json, const Correction& context)
461496
}
462497
Content default_value{0.};
463498
const auto& flowbehavior = json.getRequiredValue("flow");
464-
if ( flowbehavior == "clamp" ) {
465-
flow_ = _FlowBehavior::clamp;
466-
}
467-
else if ( flowbehavior == "error" ) {
468-
flow_ = _FlowBehavior::error;
469-
}
470-
else {
471-
flow_ = _FlowBehavior::value;
472-
default_value = resolve_content(flowbehavior, context);
499+
flow_ = parse_flowbehavior(flowbehavior);
500+
if (flow_ == _FlowBehavior::value) {
501+
default_value = resolve_content(flowbehavior, context);
473502
}
474503

475504
// set bin contents
@@ -540,16 +569,9 @@ MultiBinning::MultiBinning(const JSONObject& json, const Correction& context)
540569
}
541570

542571
const auto& flowbehavior = json.getRequiredValue("flow");
543-
if ( flowbehavior == "clamp" ) {
544-
flow_ = _FlowBehavior::clamp;
545-
}
546-
else if ( flowbehavior == "error" ) {
547-
flow_ = _FlowBehavior::error;
548-
}
549-
else {
550-
flow_ = _FlowBehavior::value;
551-
// store default value at end of content array
552-
content_.push_back(resolve_content(flowbehavior, context));
572+
flow_ = parse_flowbehavior(flowbehavior);
573+
if (flow_ == _FlowBehavior::value) {
574+
content_.push_back(resolve_content(flowbehavior, context));
553575
}
554576
}
555577

src/correctionlib/schemav2.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ class Binning(Model):
232232
description="Edges of the binning, either as a list of monotonically increasing floats or as an instance of UniformBinning. edges[i] <= x < edges[i+1] => f(x, ...) = content[i](...)"
233233
)
234234
content: List[Content]
235-
flow: Union[Content, Literal["clamp", "error"]] = Field(
235+
flow: Union[Content, Literal["clamp", "error", "wrap"]] = Field(
236236
description="Overflow behavior for out-of-bounds values"
237237
)
238238

@@ -287,7 +287,7 @@ class MultiBinning(Model):
287287
to the element at i0 in dimension 0, i1 in dimension 1, etc. and d0 = len(edges[0])-1, etc.
288288
"""
289289
)
290-
flow: Union[Content, Literal["clamp", "error"]] = Field(
290+
flow: Union[Content, Literal["clamp", "error", "wrap"]] = Field(
291291
description="Overflow behavior for out-of-bounds values"
292292
)
293293

tests/test_core.py

+11
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,17 @@ def binning(flow, uniform=True):
751751
assert corr.evaluate(2.9) == 2.0
752752
assert corr.evaluate(3.0) == 42.0
753753

754+
corr = binning(flow="wrap", uniform=use_uniform_binning)
755+
assert corr.evaluate(-3.0) == 1.0
756+
assert corr.evaluate(-2.9) == 1.0
757+
assert corr.evaluate(-1.0) == 2.0
758+
assert corr.evaluate(0.0) == 1.0
759+
assert corr.evaluate(1.0) == 1.0 if use_uniform_binning else 2.0
760+
assert corr.evaluate(2.9) == 2.0
761+
assert corr.evaluate(3.0) == 1.0
762+
assert corr.evaluate(4.6) == 2.0
763+
assert corr.evaluate(6.1) == 1.0
764+
754765
def multibinning(flow, uniform=True):
755766
if uniform:
756767
edges_x = schema.UniformBinning(n=2, low=0.0, high=3.0)

0 commit comments

Comments
 (0)