Skip to content

Commit 8e515d4

Browse files
nickreichclaude
andcommitted
Fix sarix sigma_pooling='shared' bug and update dependencies
This commit fixes a failing test for SARIX models with shared sigma pooling and updates the sarix dependency to resolve an upstream bug. Problem: -------- The test `test_sarix_shared_sigma_pooling_multiple_batches` was failing with a reshape error when using `sigma_pooling='shared'` with multiple locations. The error occurred in the sarix library at line 392 where it incorrectly used `theta` instead of `sigma` when reshaping arrays: TypeError: cannot reshape array of shape (100, 5, 6) into shape (100, 1, 1) Root Cause: ----------- The installed sarix package (v0.0.1 from elray1/sarix) contained a bug where the variable name was wrong in the sigma pooling code block. The reichlab/sarix repository had already fixed this bug in v0.2.0. Fixes: ------ 1. Updated sarix dependency from elray1/sarix to reichlab/sarix (v0.2.0) 2. Updated requires-python from >=3.9 to >=3.11 (required by newer sarix) 3. Regenerated uv.lock with updated dependencies 4. Regenerated requirements.txt and requirements-dev.txt 5. Added explicit string conversion for output_type_id in CSV output 6. Fixed test assertion to handle pandas type inference for output_type_id Test Results: ------------- ✅ test_sarix - PASSED ✅ test_sarix_shared_sigma_pooling_multiple_batches - PASSED (was failing) ✅ test_drop_level_feats - PASSED 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent f7f6556 commit 8e515d4

6 files changed

Lines changed: 64 additions & 1561 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ name = "idmodels"
33
description = "A Python module for modeling infectious disease."
44
license = {text = "MIT License"}
55
readme = "README.md"
6-
requires-python = ">=3.9"
6+
requires-python = ">=3.11"
77
classifiers = [
88
"Programming Language :: Python :: 3",
99
"License :: OSI Approved :: MIT License",

requirements/requirements-dev.txt

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ botocore==1.35.36
2020
# via aiobotocore
2121
cfgv==3.4.0
2222
# via pre-commit
23-
colorama==0.4.6
24-
# via
25-
# pytest
26-
# tqdm
2723
contourpy==1.3.0
2824
# via matplotlib
2925
coverage==7.6.4
@@ -42,19 +38,19 @@ frozenlist==1.5.0
4238
# aiosignal
4339
fsspec==2024.10.0
4440
# via s3fs
45-
iddata @ git+https://github.com/reichlab/iddata@3ad0ac0dc6d7f14488628a49bfb5228ca0643e1b
41+
iddata @ git+https://github.com/reichlab/iddata@c28849b2a02ab84e2f82876f16fee2ac60814877
4642
# via idmodels (pyproject.toml)
4743
identify==2.6.1
4844
# via pre-commit
4945
idna==3.10
5046
# via yarl
5147
iniconfig==2.0.0
5248
# via pytest
53-
jax==0.4.35
49+
jax==0.8.0
5450
# via
5551
# numpyro
5652
# sarix
57-
jaxlib==0.4.35
53+
jaxlib==0.8.0
5854
# via
5955
# jax
6056
# numpyro
@@ -100,7 +96,7 @@ numpy==2.1.3
10096
# scikit-learn
10197
# scipy
10298
# timeseriesutils
103-
numpyro==0.15.3
99+
numpyro==0.19.0
104100
# via sarix
105101
opt-einsum==3.4.0
106102
# via jax
@@ -112,7 +108,6 @@ pandas==2.2.3
112108
# via
113109
# idmodels (pyproject.toml)
114110
# iddata
115-
# sarix
116111
# timeseriesutils
117112
pillow==11.0.0
118113
# via matplotlib
@@ -147,7 +142,7 @@ ruff==0.7.2
147142
# via idmodels (pyproject.toml)
148143
s3fs==2024.10.0
149144
# via iddata
150-
sarix @ git+https://github.com/elray1/sarix@1c8995942d49afbb66637f0dc0662e1248606af4
145+
sarix @ git+https://github.com/reichlab/sarix@cedaebed367d9fc00e2b34cdfc6db09eef6a99a3
151146
# via idmodels (pyproject.toml)
152147
scikit-learn==1.5.2
153148
# via idmodels (pyproject.toml)

requirements/requirements.txt

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ attrs==24.2.0
1818
# pymmwr
1919
botocore==1.35.36
2020
# via aiobotocore
21-
colorama==0.4.6
22-
# via tqdm
2321
contourpy==1.3.0
2422
# via matplotlib
2523
cycler==0.12.1
@@ -32,15 +30,15 @@ frozenlist==1.5.0
3230
# aiosignal
3331
fsspec==2024.10.0
3432
# via s3fs
35-
iddata @ git+https://github.com/reichlab/iddata@3ad0ac0dc6d7f14488628a49bfb5228ca0643e1b
33+
iddata @ git+https://github.com/reichlab/iddata@c28849b2a02ab84e2f82876f16fee2ac60814877
3634
# via idmodels (pyproject.toml)
3735
idna==3.10
3836
# via yarl
39-
jax==0.4.35
37+
jax==0.8.0
4038
# via
4139
# numpyro
4240
# sarix
43-
jaxlib==0.4.35
41+
jaxlib==0.8.0
4442
# via
4543
# jax
4644
# numpyro
@@ -84,7 +82,7 @@ numpy==2.1.3
8482
# scikit-learn
8583
# scipy
8684
# timeseriesutils
87-
numpyro==0.15.3
85+
numpyro==0.19.0
8886
# via sarix
8987
opt-einsum==3.4.0
9088
# via jax
@@ -94,7 +92,6 @@ pandas==2.2.3
9492
# via
9593
# idmodels (pyproject.toml)
9694
# iddata
97-
# sarix
9895
# timeseriesutils
9996
pillow==11.0.0
10097
# via matplotlib
@@ -117,7 +114,7 @@ rich==13.9.4
117114
# via iddata
118115
s3fs==2024.10.0
119116
# via iddata
120-
sarix @ git+https://github.com/elray1/sarix@1c8995942d49afbb66637f0dc0662e1248606af4
117+
sarix @ git+https://github.com/reichlab/sarix@cedaebed367d9fc00e2b34cdfc6db09eef6a99a3
121118
# via idmodels (pyproject.toml)
122119
scikit-learn==1.5.2
123120
# via idmodels (pyproject.toml)

src/idmodels/sarix.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ def run(self, run_config):
106106
run_config=run_config,
107107
model_config=self.model_config
108108
)
109+
# Ensure output_type_id is string to avoid pandas inferring it as float when reading
110+
preds_df["output_type_id"] = preds_df["output_type_id"].astype(str)
109111
preds_df.to_csv(save_path, index=False)
110112

111113

tests/integration/test_sarix.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,8 @@ def test_sarix_shared_sigma_pooling_multiple_batches(tmp_path):
136136
"Output should contain predictions for all input locations"
137137
assert all(actual_df["output_type"] == "quantile"), \
138138
"All outputs should be quantiles"
139-
assert set(actual_df["output_type_id"].unique()) == set(run_config.q_labels), \
139+
# Convert output_type_id to string for comparison since pandas may infer numeric types
140+
assert set(actual_df["output_type_id"].astype(str).unique()) == set(run_config.q_labels), \
140141
"Output should contain all specified quantile levels"
141142
assert actual_df["value"].notna().all(), \
142143
"All predictions should be non-null"

0 commit comments

Comments
 (0)