Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 21 additions & 17 deletions GeneticInheritanceGraphLibrary/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class IndividualTableRow(

class BaseTable:
_RowClass = None
_non_int_fieldtypes = {}
_non_int64_fieldtypes = {} # By default all fields are int64
initial_size = 64 # default
max_resize = 2**18 # maximum number of rows by which we expand internal storage
_frozen = None # Will be overridden during init
Expand All @@ -106,7 +106,7 @@ def _create_datastore(self):
self._datastore = np.empty(
self.initial_size,
dtype=[
(name, self._non_int_fieldtypes.get(name, np.int64))
(name, self._non_int64_fieldtypes.get(name, np.int64))
for name in self._RowClass._fields
],
)
Expand Down Expand Up @@ -259,7 +259,7 @@ def _create_datastore(self):
self._datastore = np.empty(
self.initial_size,
dtype=[
(name, self._non_int_fieldtypes.get(name, np.int64))
(name, self._non_int64_fieldtypes.get(name, np.int64))
for name in self._RowClass._fields
if name not in self._extra_names
],
Expand Down Expand Up @@ -295,6 +295,10 @@ class IEdgeTable(BaseTable):
"""

_RowClass = IEdgeTableRow
_non_int64_fieldtypes = {
"child_chromosome": np.int16, # Save some space
"parent_chromosome": np.int16, # Save some space
}

# define each property by hand, for speed
@property
Expand Down Expand Up @@ -464,22 +468,22 @@ def add_row(
# need to check the values before they were put into the data array,
# as numpy silently converts floats to integers on assignment
if validate & ValidFlags.IEDGES_INTEGERS:
for is_edge, i in enumerate(
for i, val in enumerate(
(
edge,
child_left,
child_right,
parent_left,
parent_right,
child,
parent,
child_chromosome,
parent_chromosome,
child_left, # 0
child_right, # 1
parent_left, # 2
parent_right, # 3
child, # 4
parent, # 5
child_chromosome, # 6
parent_chromosome, # 7
edge, # 8
)
):
if int(i) != i:
if int(val) != val:
raise ValueError("Iedge data must be integers")
if is_edge != 0 and i < 0:
if i < 6 and val < 0:
Comment on lines +471 to +486
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

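The validation logic in the `add_row` method of `IEdgeTable` checks that every provided value is an integer, and that the interval coordinates and the child/parent node IDs (tuple indices 0–5) are non-negative; the chromosome numbers and the edge ID (indices 6–8) are exempt from the non-negativity check. This is crucial for data integrity. However, consider extracting this validation logic into a separate method to improve readability and maintainability.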

-        if validate is None:
-            validate = ~ValidFlags.IEDGES_ALL
-        # only try validating if any IEDGES flags are set
-        if (not skip_validate) and bool(validate & ValidFlags.IEDGES_ALL):
-            # need to check the values before they were put into the data array,
-            # as numpy silently converts floats to integers on assignment
-            if validate & ValidFlags.IEDGES_INTEGERS:
-                for i, val in enumerate(
-                    (
-                        child_left,  # 0
-                        child_right,  # 1
-                        parent_left,  # 2
-                        parent_right,  # 3
-                        child,  # 4
-                        parent,  # 5
-                        child_chromosome,  # 6
-                        parent_chromosome,  # 7
-                        edge,  # 8
-                    )
-                ):
-                    if int(val) != val:
-                        raise ValueError("Iedge data must be integers")
-                    if i < 6 and val < 0:
-                        raise ValueError(
-                            "Iedge data must be non-negative (except edge ID)"
-                        )
+        self._validate_iedge_data(validate, skip_validate, child_left, child_right, parent_left, parent_right, child, parent, child_chromosome, parent_chromosome, edge)

Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation.

Suggested change
for i, val in enumerate(
(
edge,
child_left,
child_right,
parent_left,
parent_right,
child,
parent,
child_chromosome,
parent_chromosome,
child_left, # 0
child_right, # 1
parent_left, # 2
parent_right, # 3
child, # 4
parent, # 5
child_chromosome, # 6
parent_chromosome, # 7
edge, # 8
)
):
if int(i) != i:
if int(val) != val:
raise ValueError("Iedge data must be integers")
if is_edge != 0 and i < 0:
if i < 6 and val < 0:

raise ValueError(
"Iedge data must be non-negative (except edge ID)"
)
Expand Down Expand Up @@ -708,7 +712,7 @@ class NodeTable(BaseExtraTable):
"""

_RowClass = NodeTableRow
_non_int_fieldtypes = {"time": np.float64, "flags": np.uint32}
_non_int64_fieldtypes = {"time": np.float64, "flags": np.uint32}

# define each property by hand, for speed
@property
Expand Down Expand Up @@ -758,7 +762,7 @@ def append(self, obj) -> int:

class IndividualTable(BaseExtraTable):
_RowClass = IndividualTableRow
_non_int_fieldtypes = {"flags": np.uint32}
_non_int64_fieldtypes = {"flags": np.uint32}
_extra_names = ["location", "parents", "metadata"]

@property
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ portion
pytest
pytest-cov
pytest-xdist
tqdm
tskit
msprime
numpy
Expand Down
25 changes: 23 additions & 2 deletions tests/gigutil.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import GeneticInheritanceGraphLibrary as gigl
import numpy as np
from tqdm.auto import tqdm

# Utilities for creating and editing gigs

Expand Down Expand Up @@ -168,6 +169,7 @@ def run(
random_seed=None,
initial_node_flags=None,
further_node_flags=None,
progress_monitor=None,
):
"""
Initialise and run a new population for a given number of generations. The last
Expand All @@ -190,6 +192,8 @@ def run(
gens, num_diploids[-gens - 1], node_flags=initial_node_flags
)
# First generation by hand, so that we can specify the sequence length
if progress_monitor:
progress_monitor = tqdm(total=gens, desc="Simulating", unit="gen")
gens -= 1
self.new_population(
gens,
Expand All @@ -204,14 +208,18 @@ def run(
self.new_population(
gens, size=num_diploids[-gens - 1], node_flags=further_node_flags
)
if progress_monitor:
progress_monitor.update(1)

self.tables.sort()
# We should probably simplify or at least sample_resolve here?
# We should also mark gen 0 as samples and unmark the others.
# Probably a parameter `simplify` would be useful?
return self.tables.copy().graph()

def run_more(self, num_diploids, seq_len, gens, random_seed=None):
def run_more(
self, num_diploids, seq_len, gens, random_seed=None, progress_monitor=None
):
"""
The num_diploids parameter can be an array of length `gens` giving the diploid
population size in each generation.
Expand All @@ -225,9 +233,13 @@ def run_more(self, num_diploids, seq_len, gens, random_seed=None):

# augment the generations
self.tables.change_times(gens)
if progress_monitor:
progress_monitor = tqdm(total=gens, desc="Simulating more", unit="gen")
while gens > 0:
gens -= 1
self.new_population(gens, size=num_diploids[-gens - 1])
if progress_monitor:
progress_monitor.update(1)

self.tables.sort()
return self.tables.copy().graph()
Expand Down Expand Up @@ -347,13 +359,16 @@ def run(
random_seed=None,
initial_node_flags=None,
further_node_flags=None,
progress_monitor=None,
):
"""
The num_diploids param can be an array of length `gens + 1` giving the diploid
population size in each generation. This allows quick growth of a population
"""
self.rng = np.random.default_rng(random_seed)
self.num_tries_for_breakpoint = 20 # number of tries to find a breakpoint
if progress_monitor:
progress_monitor = tqdm(total=gens, desc="Simulating", unit="gen")
if isinstance(num_diploids, int):
num_diploids = [num_diploids] * (gens + 1)
self.pop = self.initialise_population(
Expand All @@ -364,9 +379,11 @@ def run(
self.new_population(
gens, size=num_diploids[-gens - 1], node_flags=further_node_flags
)
if progress_monitor:
progress_monitor.update(1)
return self.tables_to_gig_without_grand_mrca()

def run_more(self, num_diploids, gens, random_seed=None):
def run_more(self, num_diploids, gens, random_seed=None, progress_monitor=None):
"""
The num_diploids parameter can be an array of length `gens` giving the diploid
population size in each generation.
Expand All @@ -378,9 +395,13 @@ def run_more(self, num_diploids, gens, random_seed=None):
if isinstance(num_diploids, int):
num_diploids = [num_diploids] * (gens)
self.tables.change_times(gens)
if progress_monitor:
progress_monitor = tqdm(total=gens, desc="Simulating more", unit="gen")
while gens > 0:
gens -= 1
self.new_population(gens, size=num_diploids[-gens - 1])
if progress_monitor:
progress_monitor.update(1)
return self.tables_to_gig_without_grand_mrca()

def find_comparable_points(self, tables, parent_nodes):
Expand Down
Loading