Skip to content

Commit 7c36b10

Browse files
committed
[multi_tag] Add docstrings and test and fix bugs for sort np.array
1 parent e6919e1 commit 7c36b10

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

nixworks/multi_tag.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,22 @@ def _in_range(point, start, end):
3939

4040

4141
def _sorting(starts, ends): # li is the start values
42+
starts = starts.tolist()
43+
ends = ends.tolist()
4244
sort = [i for i in sorted(enumerate(starts), key=lambda s: s[1])]
4345
sorted_starts = np.array([s[1] for s in sort])
4446
sorted_ends = np.array([ends[s[0]] for s in sort])
4547
return sorted_starts, sorted_ends
4648

4749

4850
def union(ref, multi_tags):
49-
# now the simple case of 2 tags
51+
"""
52+
Function to return the (non-overlapping) union of area tagged by multiple Tags
53+
or MultiTags of a specified DataArray.
54+
:param ref: the referenced array
55+
:param multi_tags: Tags or MultiTags that point to the tagged data
56+
:return: a list of DataViews
57+
"""
5058
_check_valid(multi_tags, ref)
5159
if not isinstance(ref, nix.DataArray):
5260
ref = multi_tags[0].references[ref]
@@ -56,7 +64,7 @@ def union(ref, multi_tags):
5664
end_list = []
5765
for i, st in enumerate(starts): # check if any duplicate
5866
covered = False
59-
for ti, tmp_st, tmp_ed in enumerate(zip(start_list, end_list)):
67+
for ti, (tmp_st, tmp_ed) in enumerate(zip(start_list, end_list)):
6068
if _in_range(st, tmp_st, tmp_ed) or _in_range(ends[i], tmp_st, tmp_ed):
6169
covered = True
6270
if not _in_range(ends[i], tmp_st, tmp_ed): # ends[i] > tmp_ed
@@ -75,6 +83,12 @@ def union(ref, multi_tags):
7583

7684

7785
def intersection(ref, multi_tags):
86+
"""
87+
Function to return the overlapping area in a specified DataArray tagged by multiple Tags/MultiTags.
88+
:param ref: the referenced array
89+
:param multi_tags: Tags or MultiTags that point to the tagged data
90+
:return: a DataView
91+
"""
7892
_check_valid(multi_tags, ref)
7993
if not isinstance(ref, nix.DataArray):
8094
ref = multi_tags[0].references[ref]

nixworks/test/test_multi_tag.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,16 +37,21 @@ def test_two_1d(self):
3737
t2.extent = [5]
3838
i = intersection(self.ref1d, [t1,t2])
3939
assert i is None
40+
u = union(self.ref1d, [t1,t2])
41+
np.testing.assert_array_equal(u[0], self.arr1d[0:6])
42+
np.testing.assert_array_equal(u[1], self.arr1d[10:16])
4043
# intersected
4144
t1.extent = [12]
4245
i = intersection(self.ref1d, [t1,t2])
4346
np.testing.assert_array_almost_equal(np.array(i), self.arr1d[10:13])
47+
u = union(self.ref1d, [t1,t2])
48+
np.testing.assert_array_equal(u[0], self.arr1d[0:16])
4449
# covered
4550
t1.extent = [30]
4651
i = intersection(self.ref1d, [t1,t2])
4752
np.testing.assert_array_almost_equal(np.array(i), t2.tagged_data(0)[:])
48-
# union
49-
# u = union(self.ref1d, [t1,t2])
53+
u = union(self.ref1d, [t1,t2])
54+
np.testing.assert_array_equal(u[0], t1.tagged_data(0)[:])
5055

5156
def test_multi_nd(self):
5257
d = np.zeros((2, 3))
@@ -68,5 +73,7 @@ def test_multi_nd(self):
6873
i = intersection(self.ref3d, [t1,t2])
6974
np.testing.assert_array_almost_equal(i[:], self.arr3d[3,3,3])
7075
# union
71-
# u = union(self.ref3d, [t1])
72-
76+
u = union(self.ref3d, [t1])
77+
np.testing.assert_array_almost_equal(u[0][:], self.arr3d[0:5, 0:5, 0:5])
78+
u = union(self.ref3d, [t1, t2])
79+
np.testing.assert_array_almost_equal(u[0][:], self.arr3d[0:6, 0:6, 0:6])

0 commit comments

Comments
 (0)