Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/openmc/cell.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ class Region {
//! BoundingBox if the particle is in a complex cell.
BoundingBox bounding_box_complex(vector<int32_t> postfix) const;

//! Enfource precedence: Parenthases, Complement, Intersection, Union
void add_precedence();
//! Enforce precedence between intersections and unions
void enforce_precedence();

//! Add parenthesis to enforce precedence
int64_t add_parentheses(int64_t start);
void add_parentheses(int64_t start);

//! Remove complement operators from the expression
void remove_complement_ops();
Expand Down
96 changes: 53 additions & 43 deletions src/cell.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ Region::Region(std::string region_spec, int32_t cell_id)
if (token == OP_UNION) {
simple_ = false;
// Ensure intersections have precedence over unions
add_precedence();
enforce_precedence();
break;
}
}
Expand Down Expand Up @@ -703,7 +703,7 @@ void Region::apply_demorgan(
//! precedence than unions using parentheses.
//==============================================================================

int64_t Region::add_parentheses(int64_t start)
void Region::add_parentheses(int64_t start)
{
int32_t start_token = expression_[start];
// Add left parenthesis and set new position to be after parenthesis
Expand All @@ -712,14 +712,6 @@ int64_t Region::add_parentheses(int64_t start)
}
expression_.insert(expression_.begin() + start - 1, OP_LEFT_PAREN);

// Keep track of return iterator distance. If we don't encounter a left
// parenthesis, we return an iterator corresponding to wherever the right
// parenthesis is inserted. If a left parenthesis is encountered, an iterator
// corresponding to the left parenthesis is returned. Also note that we keep
// track of a *distance* instead of an iterator because the underlying memory
// allocation may change.
std::size_t return_it_dist = 0;

// Add right parenthesis
// While the start iterator is within the bounds of infix
while (start + 1 < expression_.size()) {
Expand All @@ -733,7 +725,6 @@ int64_t Region::add_parentheses(int64_t start)
// in the region, when the operator is an intersection then include the
// operator and next surface
if (expression_[start] == OP_LEFT_PAREN) {
return_it_dist = start;
int depth = 1;
do {
start++;
Expand All @@ -750,54 +741,73 @@ int64_t Region::add_parentheses(int64_t start)
--start;
}
expression_.insert(expression_.begin() + start, OP_RIGHT_PAREN);
if (return_it_dist > 0) {
return return_it_dist;
} else {
return start - 1;
}
return;
}
}
}
// If we get here a right parenthesis hasn't been placed,
// return iterator
// If we get here a right parenthesis hasn't been placed
expression_.push_back(OP_RIGHT_PAREN);
if (return_it_dist > 0) {
return return_it_dist;
} else {
return start - 1;
}
}

//==============================================================================
//! Add parentheses to enforce operator precedence in region expressions
//!
//! This function ensures that intersection operators have higher precedence
//! than union operators by adding parentheses where needed. For example:
//! "1 2 | 3" becomes "(1 2) | 3"
//! "1 | 2 3" becomes "1 | (2 3)"
//!
//! The algorithm uses stacks to track the current operator type and its
//! position at each parenthesis depth level. When it encounters a different
//! operator at the same depth, it adds parentheses to group the
//! higher-precedence operations.
//==============================================================================

void Region::add_precedence()
void Region::enforce_precedence()
{
int32_t current_op = 0;
std::size_t current_dist = 0;
// Stack tracking the operator type at each depth (0 = no operator seen yet)
vector<int32_t> op_stack = {0};

for (int64_t i = 0; i < expression_.size(); i++) {
// Stack tracking where the operator sequence started at each depth
vector<std::size_t> pos_stack = {0};

for (int64_t i = 0; i < expression_.size(); ++i) {
int32_t token = expression_[i];

if (token == OP_LEFT_PAREN) {
// Entering a new parenthesis level - push new tracking state
op_stack.push_back(0);
pos_stack.push_back(0);
continue;
} else if (token == OP_RIGHT_PAREN) {
// Exiting a parenthesis level - pop tracking state (keep at least one)
if (op_stack.size() > 1) {
op_stack.pop_back();
pos_stack.pop_back();
}
continue;
}

if (token == OP_UNION || token == OP_INTERSECTION) {
if (current_op == 0) {
// Set the current operator if is hasn't been set
current_op = token;
current_dist = i;
} else if (token != current_op) {
// If the current operator doesn't match the token, add parenthesis to
// assert precedence
if (current_op == OP_INTERSECTION) {
i = add_parentheses(current_dist);
if (op_stack.back() == 0) {
// First operator at this depth - record it and its position
op_stack.back() = token;
pos_stack.back() = i;
} else if (token != op_stack.back()) {
// Encountered a different operator at the same depth - need to add
// parentheses to enforce precedence. Intersection has higher
// precedence, so we parenthesize the intersection terms.
if (op_stack.back() == OP_INTERSECTION) {
add_parentheses(pos_stack.back());
} else {
i = add_parentheses(i);
add_parentheses(i);
}
current_op = 0;
current_dist = 0;

// Restart the scan since we modified the expression
i = -1; // Will be incremented to 0 by the for loop
op_stack = {0};
pos_stack = {0};
}
} else if (token > OP_COMPLEMENT) {
// If the token is a parenthesis reset the current operator
current_op = 0;
current_dist = 0;
}
}
}
Expand Down
1 change: 1 addition & 0 deletions tests/cpp_unit_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set(TEST_NAMES
test_math
test_mcpl_stat_sum
test_mesh
test_region
# Add additional unit test files here
)

Expand Down
101 changes: 101 additions & 0 deletions tests/cpp_unit_tests/test_region.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
#include <catch2/catch_test_macros.hpp>

#include "openmc/cell.h"
#include "openmc/surface.h"

#include <pugixml.hpp>

namespace {

// Helper class to set up and tear down test surfaces
class SurfaceFixture {
public:
SurfaceFixture()
{
pugi::xml_document doc;
pugi::xml_node surf_node = doc.append_child("surface");
surf_node.set_name("surface");
surf_node.append_attribute("id") = "0";
surf_node.append_attribute("type") = "x-plane";
surf_node.append_attribute("coeffs") = "1";

for (int i = 1; i < 10; ++i) {
surf_node.attribute("id") = i;
openmc::model::surfaces.push_back(
std::make_unique<openmc::SurfaceXPlane>(surf_node));
openmc::model::surface_map[i] = i - 1;
}
}

~SurfaceFixture()
{
openmc::model::surfaces.clear();
openmc::model::surface_map.clear();
}
};

} // anonymous namespace

TEST_CASE("Test region simplification")
{
SurfaceFixture fixture;

SECTION("Original bug case from issue #3685")
{
// Input: "-1 2 (-3 4) | (-5 6)" was being incorrectly interpreted
auto region = openmc::Region("(-1 2 (-3 4) | (-5 6))", 0);
REQUIRE(region.str() == " ( ( -1 2 ( -3 4 ) ) | ( -5 6 ) )");
}

SECTION("Simple union - no extra parentheses needed")
{
auto region = openmc::Region("1 | 2", 0);
REQUIRE(region.str() == " 1 | 2");
}

SECTION("Intersection then union")
{
// Intersection should have higher precedence, so (1 2) grouped
auto region = openmc::Region("1 2 | 3", 0);
REQUIRE(region.str() == " ( 1 2 ) | 3");
}

SECTION("Union then intersection")
{
// The (2 3) intersection should be grouped
auto region = openmc::Region("1 | 2 3", 0);
REQUIRE(region.str() == " 1 | ( 2 3 )");
}

SECTION("Nested parentheses preserved")
{
// These parentheses are meaningful and should be preserved
auto region = openmc::Region("(1 | 2) (3 | 4)", 0);
REQUIRE(region.str() == " ( 1 | 2 ) ( 3 | 4 )");
}

SECTION("Deep nesting")
{
auto region = openmc::Region("((1 2) | (3 4)) 5", 0);
REQUIRE(region.str() == " ( ( 1 2 ) | ( 3 4 ) ) 5");
}

SECTION("Multiple unions")
{
auto region = openmc::Region("1 | 2 | 3", 0);
REQUIRE(region.str() == " 1 | 2 | 3");
}

SECTION("Multiple intersections")
{
auto region = openmc::Region("1 2 3", 0);
// Simple cell - no operators in output
REQUIRE(region.str() == " 1 2 3");
}

SECTION("Complex mixed expression")
{
auto region = openmc::Region("1 2 | 3 4 | 5 6", 0);
REQUIRE(region.str() == " ( 1 2 ) | ( 3 4 ) | ( 5 6 )");
}
}
Loading