Skip to content

Commit e1e20ac

Browse files
committed
Address pr comments
1 parent 4ca864d commit e1e20ac

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

src/vivarium/framework/randomness/stream.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -274,15 +274,11 @@ def filter_for_probability(
274274
return population
275275

276276
# Check for null values in probabilities
277-
if isinstance(probability, float):
277+
if isinstance(probability, (float, int)):
278278
if np.isnan(probability):
279279
raise ValueError("Probabilities contain null values")
280-
elif isinstance(probability, pd.Series):
281-
if probability.isna().any():
282-
raise ValueError("Probabilities contain null values")
283-
elif not isinstance(probability, (int, np.integer)):
284-
# Handle lists, tuples, and numpy arrays (but skip int types)
285-
if np.any(np.isnan(probability)):
280+
else:
281+
if np.isnan(probability).any():
286282
raise ValueError("Probabilities contain null values")
287283

288284
if isinstance(population, pd.Index):

tests/framework/randomness/test_stream.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -275,21 +275,26 @@ def test_stream_rate_conversion_config(
275275
assert sim._randomness._rate_conversion_type == rate_conversion
276276

277277

278-
def test_filter_for_probability_error_with_null_values() -> None:
278+
@pytest.mark.parametrize(
279+
"probs",
280+
[
281+
[0.3, np.nan, 0.5, np.nan, 0.7],
282+
np.nan,
283+
pd.Series([0.2, 0.3, np.nan, 0.4, 0.5]),
284+
],
285+
)
286+
def test_filter_for_probability_error_with_null_values(
287+
probs: float | list[float] | pd.Series[float],
288+
) -> None:
279289
randomness_stream = RandomnessStream(
280290
"test", lambda: pd.Timestamp(2020, 1, 1), 1, IndexMap()
281291
)
282292
pop = pd.DataFrame({"age": [10, 11, 12, 13, 14], "id": [1, 2, 3, 4, 5]}).set_index("id")
283293
with pytest.raises(ValueError, match="Probabilities contain null values"):
284-
randomness_stream.filter_for_probability(pop, [0.3, np.nan, 0.5, np.nan, 0.7])
294+
randomness_stream.filter_for_probability(pop, probs)
285295

286296
with pytest.raises(ValueError, match="Probabilities contain null values"):
287-
randomness_stream.filter_for_probability(pop, np.nan)
297+
randomness_stream.filter_for_probability(pop, probs)
288298

289299
with pytest.raises(ValueError, match="Probabilities contain null values"):
290-
randomness_stream.filter_for_probability(
291-
pop, pd.Series([0.2, 0.3, np.nan, 0.4, 0.5], index=pop.index)
292-
)
293-
294-
# Doesn't raise
295-
randomness_stream.filter_for_probability(pop, [0.3, 0.4, 0.5, 0.6, 0.7])
300+
randomness_stream.filter_for_probability(pop, probs)

0 commit comments

Comments
 (0)