Skip to content

Commit dd2b98a

Browse files
authored
New implementation of _sort_field_names based on Python list sorting and unit test (#682)
1 parent 71df1a5 commit dd2b98a

File tree

2 files changed

+108
-133
lines changed

2 files changed

+108
-133
lines changed

packages/seacas/scripts/exomerge3.py

Lines changed: 58 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -473,31 +473,27 @@ class ExodusModel(object):
473473
ELEMENT_ORDER["line3"] = 2
474474
ELEMENT_ORDER["point"] = 1
475475

476-
# define components of multi-component fields
477-
MULTI_COMPONENT_FIELD_SUBSCRIPTS = dict()
478-
MULTI_COMPONENT_FIELD_SUBSCRIPTS["vector"] = ("x", "y", "z")
479-
MULTI_COMPONENT_FIELD_SUBSCRIPTS["symmetric_3x3_tensor"] = (
480-
"xx",
481-
"yy",
482-
"zz",
483-
"xy",
484-
"yz",
485-
"zx",
486-
)
487-
MULTI_COMPONENT_FIELD_SUBSCRIPTS["full_3x3_tensor"] = (
488-
"xx",
489-
"yy",
490-
"zz",
491-
"xy",
492-
"yz",
493-
"zx",
494-
"yx",
495-
"zy",
496-
"xz",
497-
)
498-
ALL_MULTI_COMPONENT_FIELD_SUBSCRIPTS = set(
499-
itertools.chain(*list(MULTI_COMPONENT_FIELD_SUBSCRIPTS.values()))
500-
)
476+
# A dictionary defining the order of components in multi-component fields.
477+
# See "_sort_field_names" method for details.
478+
_FIELD_NAME_SUBSCRIPT_ORDER = {
479+
"xx": 1,
480+
"yy": 2,
481+
"zz": 3,
482+
"xy": 4,
483+
"yz": 5,
484+
"zx": 6,
485+
"yx": 7,
486+
"zy": 8,
487+
"xz": 9,
488+
"x": 10,
489+
"y": 11,
490+
"z": 12,
491+
}
492+
493+
# Regular expression used to parse field names. It splits the name into three named groups: base_name, component, and integration_point.
494+
# See "_sort_field_names" method for details.
495+
_FIELD_NAME_REGEX = re.compile(fr"^(?P<base_name>.*?)(?:[_]?)(?P<component>{'|'.join(_FIELD_NAME_SUBSCRIPT_ORDER.keys())})?(?:[_]?(?P<integration_point>\d+))?$")
496+
501497

502498
def __init__(self):
503499
"""Initialize the model."""
@@ -6915,123 +6911,54 @@ def create_timestep(self, timestep):
69156911
for name, values in list(self.global_variables.items()):
69166912
values.insert(timestep_index, self._get_default_field_value(name))
69176913

6918-
def _replace_name_case(self, new_list, original_list):
6919-
"""
6920-
Return the lowercase version of all strings in the given list.
6921-
6922-
Example:
6923-
>>> model._replace_name_case(['x', 'z', 'fred'], ['X', 'Fred', 'Z'])
6924-
['X', 'Z', 'Fred']
6925-
6926-
"""
6927-
original_case = dict((x.lower(), x) for x in original_list)
6928-
if len(original_case) != len(original_list):
6929-
self._warning(
6930-
"Ambiguous string case.",
6931-
"There are multiple strings in the list which have "
6932-
"identical lowercase representations. One will be "
6933-
"chosen at random.",
6934-
)
6935-
for item in new_list:
6936-
if item.lower() not in original_case:
6937-
self._bug(
6938-
"Unrecognized string.",
6939-
'The string "%s" appears in the new list but '
6940-
"not in the original list." % item,
6941-
)
6942-
return [original_case[x.lower()] for x in new_list]
6943-
6944-
def _sort_field_names(self, original_field_names):
6914+
def _sort_field_names(self, original_field_names: list[str]) -> list[str]:
69456915
"""
69466916
Return field names sorted in a SIERRA-friendly manner.
69476917
69486918
In order for SIERRA to recognize vectors, tensors, and element fields
69496919
with multiple integration points, fields must be sorted in a specific
6950-
order. This function provides that sort order.
6920+
order. This function provides that sort order.
69516921
69526922
As fields within exomerge are stored in a set, exomerge has no internal
6953-
or natural field order. This routine is only necessary for writing to
6923+
or natural field order. This routine is only necessary for writing to
69546924
ExodusII files.
69556925
6926+
This method recognizes the following field naming patterns:
6927+
6928+
- <base_name>_<component>_<integration_point>. E.g. "unrotated_stress_xx_1"
6929+
- <base_name>_<integration_point>. E.g. "ln_strain_1"
6930+
- <base_name>_<component>. E.g. "Displacement_X" or "SIGMA_XX"
6931+
- <base_name>. E.g. "temperature"
6932+
6933+
Same patterns but omitting the underscore are also recognized:
6934+
6935+
- <base_name><component><integration_point>. E.g. "unrotated_stressxx1"
6936+
- <base_name><integration_point>. E.g. "ln_strain1"
6937+
- <base_name><component>. E.g. "DisplacementX" or "SIGMAXX"
6938+
- <base_name>. E.g. "temperature"
6939+
6940+
The sorting is done by the base name (alphabetically), then by the
6941+
integration point (1, 2, 3), and finally by the component
6942+
(according to the "_FIELD_NAME_SUBSCRIPT_ORDER" dictionary).
69566943
"""
6957-
field_names = [x.lower() for x in original_field_names]
6958-
# Look through all fields to find multi-component fields and store
6959-
# these as tuples of the following form:
6960-
# ('base_name', 'component', integration_points)
6961-
multicomponent_fields = set()
6962-
for name in field_names:
6963-
# see if it has an integration point
6964-
if re.match(".*_[0-9]+$", name):
6965-
(name, integration_point) = name.rsplit("_", 1)
6966-
integration_point = int(integration_point)
6967-
else:
6968-
integration_point = None
6969-
# see if it possibly has a component
6970-
if re.match(".*_.+$", name):
6971-
component = name.rsplit("_", 1)[1]
6972-
if component in self.ALL_MULTI_COMPONENT_FIELD_SUBSCRIPTS:
6973-
name = name.rsplit("_", 1)[0]
6974-
multicomponent_fields.add((name, component, integration_point))
6975-
# now sort multi-component fields
6976-
base_names = set(x for x, _, _ in multicomponent_fields)
6977-
sorted_field_names = dict()
6978-
field_names = set(field_names)
6979-
for base_name in base_names:
6980-
# find all components of this form
6981-
components = set(
6982-
x for name, x, _ in multicomponent_fields if name == base_name
6983-
)
6984-
# find max integration point value
6985-
integration_points = set(
6986-
x
6987-
for name, _, x in multicomponent_fields
6988-
if name == base_name and x is not None
6989-
)
6990-
if integration_points:
6991-
integration_point_count = max(
6992-
x
6993-
for name, _, x in multicomponent_fields
6994-
if name == base_name and x is not None
6995-
)
6996-
else:
6997-
integration_point_count = None
69986944

6999-
# see if the components match the form of something
7000-
matching_form = None
7001-
for form, included_components in list(
7002-
self.MULTI_COMPONENT_FIELD_SUBSCRIPTS.items()
7003-
):
7004-
if set(included_components) == components:
7005-
matching_form = form
7006-
if not matching_form:
7007-
continue
7008-
# see if all components and integration points are present
7009-
mid = [
7010-
"_" + x for x in self.MULTI_COMPONENT_FIELD_SUBSCRIPTS[matching_form]
7011-
]
7012-
if integration_point_count is None:
7013-
last = [""]
7014-
else:
7015-
last = ["_" + str(x + 1) for x in range(integration_point_count)]
7016-
all_names = [base_name + m + s for s in last for m in mid]
7017-
if set(all_names).issubset(field_names):
7018-
sorted_field_names[all_names[0]] = all_names
7019-
field_names = field_names - set(all_names)
7020-
# sort field names which are not part of multicomponent fields
7021-
field_names = sorted(field_names)
7022-
# for each list of field names, find place to splice into list
7023-
place_to_insert = dict()
7024-
for name in list(sorted_field_names.keys()):
7025-
place = bisect.bisect_left(field_names, name)
7026-
if place not in place_to_insert:
7027-
place_to_insert[place] = [name]
7028-
else:
7029-
place_to_insert[place].append(name)
7030-
# splice them in
7031-
for place in sorted(list(place_to_insert.keys()), reverse=True):
7032-
for name in place_to_insert[place]:
7033-
field_names[place:place] = sorted_field_names[name]
7034-
return self._replace_name_case(field_names, original_field_names)
6945+
def _sorting_key(elem: str) -> tuple[str, int, int]:
6946+
"""This inner function transforms each element of the list "original_field_names"
6947+
into another element that will be used for sorting purposes
6948+
"""
6949+
6950+
match = self._FIELD_NAME_REGEX.match(elem.lower()).groupdict() # type: ignore
6951+
6952+
base_name = str(match["base_name"])
6953+
integration_point = int(match["integration_point"]) if match["integration_point"] is not None else 0
6954+
6955+
# Transform the component to a letter according to the _FIELD_NAME_SUBSCRIPT_ORDER
6956+
component = self._FIELD_NAME_SUBSCRIPT_ORDER[match["component"]] if match["component"] is not None else 0
6957+
6958+
return (base_name, integration_point, component)
6959+
6960+
original_field_names.sort(key=_sorting_key)
6961+
return original_field_names
70356962

70366963
def _reorder_list(self, the_list, new_index):
70376964
"""

packages/seacas/scripts/tests/exomerge_unit_test.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def _topology_test(self):
603603
# Tests should return None if successful (no return statement needed)
604604
# Tests should return False if the test was unable to be run.
605605
# Tests should raise an exception or exit(1) if unsuccessful.
606-
606+
607607
def _test_calculate_element_volumes(self):
608608
ids = self.model._get_standard_element_block_ids()
609609
if not ids:
@@ -1874,8 +1874,47 @@ def test(self):
18741874
print("\nSuccess")
18751875

18761876

1877+
# The following functions are unit tests for private functions of exomerge.
1878+
def _test_sort_field_names(self):
1879+
"""Unittest for _sort_field_names method.
1880+
1881+
In this test, we will create a list of field names that are sorted according to
1882+
SIERRA conventions, then randomly shuffle them to simulate unsorted input.
1883+
1884+
Both naming conventions with and without underscores will be tested.
1885+
"""
1886+
1887+
# List of all possible field names sorted according to SIERRA conventions.
1888+
sorted_names = [
1889+
"Displacement_X", "Displacement_Y", "Displacement_Z",
1890+
"ln_strain_1", "ln_strain_2", "ln_strain_3", "ln_strain_4", # scalar field defined in integration points
1891+
"SIGMA_XX", "SIGMA_YY", "SIGMA_ZZ", "SIGMA_XY", "SIGMA_YZ", "SIGMA_ZX", "SIGMA_YX", "SIGMA_ZY", "SIGMA_XZ", # asymmetric tensor
1892+
"unrotated_stress_xx_1", "unrotated_stress_yy_1", "unrotated_stress_zz_1", "unrotated_stress_xy_1", "unrotated_stress_yz_1", "unrotated_stress_zx_1", # Symmetric tensor with integration points
1893+
"unrotated_stress_xx_2", "unrotated_stress_yy_2", "unrotated_stress_zz_2", "unrotated_stress_xy_2", "unrotated_stress_yz_2", "unrotated_stress_zx_2",
1894+
"unrotated_stress_xx_3", "unrotated_stress_yy_3", "unrotated_stress_zz_3", "unrotated_stress_xy_3", "unrotated_stress_yz_3", "unrotated_stress_zx_3",
1895+
"unrotated_stress_xx_12", "unrotated_stress_yy_12", "unrotated_stress_zz_12", "unrotated_stress_xy_12", "unrotated_stress_yz_12", "unrotated_stress_zx_12", # Try with a number bigger than 9
1896+
"velocity" # scalar field
1897+
]
1898+
1899+
# Randomly shuffle the names to simulate unsorted input
1900+
unsorted_names = sorted_names.copy()
1901+
random.shuffle(unsorted_names)
1902+
assert sorted_names == self.model._sort_field_names(unsorted_names), "Failed to sort names with underscores.\nExpected: {}\nGot: {}".format(
1903+
sorted_names, self.model._sort_field_names(unsorted_names)
1904+
)
1905+
1906+
# Test sorting names without underscores
1907+
sorted_names_no_underscores = [name.replace("_", "") for name in sorted_names]
1908+
unsorted_names_no_underscores = sorted_names_no_underscores.copy()
1909+
random.shuffle(unsorted_names_no_underscores)
1910+
assert sorted_names_no_underscores == self.model._sort_field_names(unsorted_names_no_underscores), "Failed to sort names without underscores. \nExpected: {}\nGot: {}".format(
1911+
sorted_names_no_underscores, self.model._sort_field_names(unsorted_names_no_underscores)
1912+
)
1913+
1914+
18771915
# if this module is executed (as opposed to imported), run the tests
1878-
if __name__ == "__main__":
1916+
if __name__ == "__main__":
1917+
18791918
if len(sys.argv) > 2:
18801919
sys.stderr.write("Invalid syntax.\n")
18811920
exit(1)
@@ -1885,3 +1924,12 @@ def test(self):
18851924
tester.min_tests = int(sys.argv[1])
18861925
tester.max_tests = tester.min_tests
18871926
tester.test()
1927+
1928+
# Run unittest for private functions
1929+
print("\nRunning unittest for private functions in exomerge.py...")
1930+
input_dir = os.path.dirname(__file__)
1931+
temp_exo_path = os.path.join(input_dir, "exomerge_unit_test.e")
1932+
tester = ExomergeUnitTester()
1933+
tester.model = exomerge.import_model(temp_exo_path)
1934+
print("[1]_test_sort_field_names")
1935+
tester._test_sort_field_names()

0 commit comments

Comments
 (0)