Traversal tests, bug fix #62

Open · wants to merge 3 commits into master
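
This PR adds a test suite for pyxlma.lmalib.traversal.OneToManyTraversal, exercising initialization on an empty dataset, child counting, parent-id replication, and pruning of a four-level storm > flash > stroke > trigger hierarchy from the top, middle, and bottom of the tree. It also fixes reduce_to_entities to handle the case where xarray's groupby returns slice objects instead of integer indexes.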
4 changes: 4 additions & 0 deletions pyxlma/lmalib/traversal.py
@@ -222,6 +222,8 @@ def reduce_to_entities(self, entity_id_var, entity_ids):
             e_iter = (np.atleast_1d(p_group[eid]) for eid in last_entity_ids
                       if eid in p_group)
             e_idx = list(itertools.chain.from_iterable(e_iter))
+            if len(e_idx) > 0 and type(e_idx[0]) == slice:
+                e_idx = np.array([e.start for e in e_idx])
             if len(e_idx) == 0:
                 # xarray doesn't accept an empty array as a valid index
                 e_idx = slice(0, 0)
@@ -242,6 +244,8 @@ def reduce_to_entities(self, entity_id_var, entity_ids):
             e_iter = (np.atleast_1d(e_group[eid]) for eid in entity_ids
                       if eid in e_group)
             e_idx = list(itertools.chain.from_iterable(e_iter))
+            if len(e_idx) > 0 and type(e_idx[0]) == slice:
+                e_idx = np.array([e.start for e in e_idx])
             last_entity_ids = entity_ids  # == dataset[e_var].data
             if len(e_idx) == 0:
                 # xarray doesn't accept an empty array as a valid index
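
Note on the fix above: depending on the xarray version, the .groups dict returned by groupby() may map each group label to a slice object (when the group's members are contiguous along the dimension) rather than to a list of integer positions. The two added guards convert such slices back to integer positions via their .start attribute, which is sufficient when each group holds a single element, as when an entity is grouped by its own id variable. A minimal sketch of the behavior being guarded against (illustrative only; the exact indexer type depends on the installed xarray version):

    import numpy as np
    import xarray as xr

    d = xr.Dataset({'storm_id': ('nstorms', np.array([0, 1, 2]))})
    groups = d.groupby('storm_id').groups
    # On some xarray versions each value is a slice, e.g.
    # {0: slice(0, 1), 1: slice(1, 2), 2: slice(2, 3)};
    # on others it is a list of integer positions, e.g. {0: [0], ...}.
    print(groups)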
172 changes: 172 additions & 0 deletions tests/test_traversal.py
@@ -0,0 +1,172 @@
import numpy as np
from numpy.testing import assert_equal
import xarray as xr

from pyxlma.lmalib.traversal import OneToManyTraversal

def get_four_level_data():
    # Synthetic four-level hierarchy: storm > flash > stroke > trigger.
    storm_id = [0,1,2]
    flash_parent_storm_id = [0,0,0,2,2,2,2,2]
    flash_id = [1,2,3,4,5,6,7,8]
    stroke_parent_flash_id = [1,1,2,3, 4, 4, 4, 6, 8, 8, 8]
    stroke_id = [4,6,7,9,13,14,15,19,20,23,46]
    trig_parent_stroke_id = [4,4,4,4,6, 7, 7, 9,13,13,14,14,15,19,20,20,23,46]
    trig_id = [1,3,5,8,9,10,12,16,18,19,20,22,23,25,26,30,31,32]
    trig_parent_storm_id = [0,0,0,0,0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
    trig_parent_flash_id = [1,1,1,1,1, 2, 2, 3, 4, 4, 4, 4, 4, 6, 8, 8, 8, 8]

    storm_child_flash_count = [3,0,5]
    flash_child_stroke_count = [2,1,1,3,0,1,0,3]
    stroke_child_trig_count = [4,1,2,1,2,2,1,1,2,1,1]

    storm_child_stroke_count = [4,0,7]
    storm_child_trig_count = [8,0,10]

    storm_dim = 'nstorms'
    flash_dim = 'nflashes'
    stroke_dim = 'nstrokes'
    trig_dim = 'ntrigs'

    d = xr.Dataset({
        'storm_id': (storm_dim, storm_id),
        'flash_id': (flash_dim, flash_id),
        'flash_parent_storm_id': (flash_dim, flash_parent_storm_id),
        'stroke_id': (stroke_dim, stroke_id),
        'stroke_parent_flash_id': (stroke_dim, stroke_parent_flash_id),
        'trig_id': (trig_dim, trig_id),
        'trig_parent_stroke_id': (trig_dim, trig_parent_stroke_id),
        'trig_parent_flash_id': (trig_dim, trig_parent_flash_id),
        'trig_parent_storm_id': (trig_dim, trig_parent_storm_id),
        'storm_child_flash_count': (storm_dim, storm_child_flash_count),
        'storm_child_stroke_count': (storm_dim, storm_child_stroke_count),
        'storm_child_trig_count': (storm_dim, storm_child_trig_count),
        'flash_child_stroke_count': (flash_dim, flash_child_stroke_count),
        'stroke_child_trig_count': (stroke_dim, stroke_child_trig_count),
    })
    d = d.set_coords(['stroke_id', 'flash_id', 'storm_id', 'trig_id',
                      'stroke_parent_flash_id', 'trig_parent_stroke_id',
                      'flash_parent_storm_id'])
    # Internal consistency checks on the synthetic data.
    assert len(flash_id) == len(flash_parent_storm_id)
    assert len(stroke_id) == len(stroke_parent_flash_id)
    assert len(trig_id) == len(trig_parent_stroke_id)
    assert sum(storm_child_flash_count) == len(flash_id)
    assert sum(storm_child_stroke_count) == len(stroke_id)
    assert sum(storm_child_trig_count) == len(trig_id)
    assert sum(flash_child_stroke_count) == len(stroke_id)
    assert sum(stroke_child_trig_count) == len(trig_id)
    return d

def get_four_level_data_traversal():
    d = get_four_level_data()

    entity_vars = ('storm_id', 'flash_id', 'stroke_id', 'trig_id')
    parent_vars = ('flash_parent_storm_id',
                   'stroke_parent_flash_id',
                   'trig_parent_stroke_id')
    traversal = OneToManyTraversal(d, entity_vars, parent_vars)
    return d, traversal

def get_four_level_empty_data_traversal():
    d = get_four_level_data()
    # Select zero elements along every dimension to get an empty dataset.
    empty_all = {dim: slice(0, 0) for dim in d.dims}
    d = d[empty_all]

    entity_vars = ('storm_id', 'flash_id', 'stroke_id', 'trig_id')
    parent_vars = ('flash_parent_storm_id',
                   'stroke_parent_flash_id',
                   'trig_parent_stroke_id')
    traversal = OneToManyTraversal(d, entity_vars, parent_vars)
    return d, traversal

def test_empty_traversal_init():
    # Construction on an all-empty dataset should not raise.
    d, traversal = get_four_level_empty_data_traversal()

def test_count_children():
    d, traversal = get_four_level_data_traversal()

    # validation data
    storm_child_count = d['storm_child_flash_count'].data
    flash_child_count = d['flash_child_stroke_count'].data
    stroke_child_count = d['stroke_child_trig_count'].data
    storm_child_stroke_count = d['storm_child_stroke_count'].data
    storm_child_trig_count = d['storm_child_trig_count'].data

    n_storms = traversal.count_children('storm_id')[0]
    assert_equal(storm_child_count, n_storms)
    n_flashes = traversal.count_children('flash_id')[0]
    assert_equal(flash_child_count, n_flashes)
    n_strokes = traversal.count_children('stroke_id')[0]
    assert_equal(stroke_child_count, n_strokes)

    all_counts = traversal.count_children('storm_id', 'trig_id')
    assert_equal(storm_child_count, all_counts[0])
    assert_equal(flash_child_count, all_counts[1])
    assert_equal(stroke_child_count, all_counts[2])

    # Cross-check trigger counts per storm by grouping directly.
    grouper = d.groupby('trig_parent_storm_id').groups
    count = [len(grouper[eid]) if (eid in grouper) else 0
             for eid in d['storm_id'].data]
    assert_equal(storm_child_trig_count, count)

def test_replicate_parent_ids():
    d, traversal = get_four_level_data_traversal()
    trig_parent_storm_ids = traversal.replicate_parent_ids('storm_id',
                                                           'trig_parent_stroke_id')
    trig_parent_flash_ids = traversal.replicate_parent_ids('flash_id',
                                                           'trig_parent_stroke_id')
    trig_parent_stroke_ids = traversal.replicate_parent_ids('stroke_id',
                                                            'trig_parent_stroke_id')

    assert_equal(d['trig_parent_storm_id'].data, trig_parent_storm_ids)
    assert_equal(d['trig_parent_flash_id'].data, trig_parent_flash_ids)
    assert_equal(d['trig_parent_stroke_id'].data, trig_parent_stroke_ids)


def test_prune_from_middle():
    d, traversal = get_four_level_data_traversal()

    reduced_stroke_id = [13,15,23]
    d = traversal.reduce_to_entities('stroke_id', reduced_stroke_id)
    reduced_storm_id = [2,]
    reduced_flash_id = [4,8]
    reduced_trig_id = [18,19,23,31]
    assert_equal(d['storm_id'].data, reduced_storm_id)
    assert_equal(d['flash_id'].data, reduced_flash_id)
    assert_equal(d['stroke_id'].data, reduced_stroke_id)
    assert_equal(d['trig_id'].data, reduced_trig_id)

def test_prune_from_bottom():
    d, traversal = get_four_level_data_traversal()

    trig_idx = slice(7,10)
    reduced_storm_id = np.unique(d['trig_parent_storm_id'][trig_idx].data)
    reduced_flash_id = np.unique(d['trig_parent_flash_id'][trig_idx].data)
    reduced_stroke_id = np.unique(d['trig_parent_stroke_id'][trig_idx].data)
    reduced_trig_id = d['trig_id'][trig_idx].data
    d = traversal.reduce_to_entities('trig_id', reduced_trig_id)
    assert_equal(d['trig_id'].data, reduced_trig_id)
    assert_equal(d['stroke_id'].data, reduced_stroke_id)
    assert_equal(d['flash_id'].data, reduced_flash_id)
    assert_equal(d['storm_id'].data, reduced_storm_id)

def test_prune_from_top():
    d, traversal = get_four_level_data_traversal()
    # Storm 1 has no children, so pruning to it empties every lower level.
    reduced_storm_id = [1,]
    d = traversal.reduce_to_entities('storm_id', reduced_storm_id)
    reduced_stroke_id = np.asarray([])
    reduced_flash_id = np.asarray([])
    reduced_trig_id = np.asarray([])
    assert_equal(d['storm_id'], reduced_storm_id)
    assert_equal(d['flash_id'], reduced_flash_id)
    assert_equal(d['stroke_id'], reduced_stroke_id)
    assert_equal(d['trig_id'], reduced_trig_id)

    reduced_storm_id = [2,]
    d = traversal.reduce_to_entities('storm_id', reduced_storm_id)
    reduced_flash_id = [4,5,6,7,8]
    reduced_stroke_id = [13,14,15,19,20,23,46]
    reduced_trig_id = [18,19,20,22,23,25,26,30,31,32]
    assert_equal(d['storm_id'].data, reduced_storm_id)
    assert_equal(d['flash_id'].data, reduced_flash_id)
    assert_equal(d['stroke_id'].data, reduced_stroke_id)
    assert_equal(d['trig_id'].data, reduced_trig_id)
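
The new tests follow standard pytest discovery conventions (test_* functions in tests/), so, assuming pytest is installed, the suite can be run from the repository root with:

    pytest tests/test_traversal.py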