Skip to content

Commit f2cc032

Browse files
jgallowa07claude
andcommitted
Merge fix-ci-tests branch into main
This merge includes fixes for: - Python 3.9 compatibility with typing_extensions.Self - include_counts parameter default changed to False for backward compatibility - Parameter naming consistency (coef_lasso_shift → scale_coeff_lasso_shift) - Doctest floating point zero representation using .replace(-0.0, 0.0) All CI tests now pass across Python 3.9-3.11 on Ubuntu and macOS. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <[email protected]>
2 parents a40bf74 + 4df4418 commit f2cc032

File tree

4 files changed

+26
-26
lines changed

4 files changed

+26
-26
lines changed

multidms/biophysical.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ def proximal_box_constraints(params, hyperparameters, *args, **kwargs):
354354
def proximal_objective(Dop, params, hyperparameters, scaling=1.0):
355355
"""ADMM generalized lasso optimization."""
356356
(
357-
coef_lasso_shift,
357+
scale_coeff_lasso_shift,
358358
admm_niter,
359359
admm_tau,
360360
admm_mu,
@@ -368,7 +368,7 @@ def proximal_objective(Dop, params, hyperparameters, scaling=1.0):
368368
# see https://pyproximal.readthedocs.io/en/stable/index.html
369369
beta_ravel, shift_ravel = pyproximal.optimization.primal.LinearizedADMM(
370370
pyproximal.L2(b=beta_ravel),
371-
pyproximal.L1(sigma=scaling * coef_lasso_shift),
371+
pyproximal.L1(sigma=scaling * scale_coeff_lasso_shift),
372372
Dop,
373373
niter=admm_niter,
374374
tau=admm_tau,

multidms/data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ class notes.
100100
a unique name based upon the number of data objects
101101
instantiated.
102102
include_counts : bool
103-
If True (default), expects 'pre_count' and 'post_count' columns in the
104-
input DataFrame and includes them in the data arrays. If False, these
105-
columns are not required and count data will not be available.
103+
If True, expects 'pre_count' and 'post_count' columns in the
104+
input DataFrame and includes them in the data arrays. If False (default),
105+
these columns are not required and count data will not be available.
106106
107107
Example
108108
-------
@@ -197,7 +197,7 @@ def __init__(
197197
assert_site_integrity=False,
198198
verbose=False,
199199
name=None,
200-
include_counts=True,
200+
include_counts=False,
201201
):
202202
"""See main class docstring."""
203203
# Check and initialize conditions attribute

multidms/model.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,13 +146,13 @@ class Model:
146146
features included. These are automatically updated each time you
147147
request the property.
148148
149-
>>> model.get_mutations_df() # doctest: +NORMALIZE_WHITESPACE
149+
>>> model.get_mutations_df().replace(-0.0, 0.0) # doctest: +NORMALIZE_WHITESPACE
150150
wts sites muts times_seen_a times_seen_b beta_a beta_b shift_b \
151151
mutation
152152
M1E M 1 E 1 3 0.0 0.0 0.0
153-
M1W M 1 W 1 0 0.0 -0.0 0.0
154-
G3P G 3 P 1 4 -0.0 -0.0 -0.0
155-
G3R G 3 R 1 2 -0.0 0.0 -0.0
153+
M1W M 1 W 1 0 0.0 0.0 0.0
154+
G3P G 3 P 1 4 0.0 0.0 0.0
155+
G3R G 3 R 1 2 0.0 0.0 0.0
156156
<BLANKLINE>
157157
predicted_func_score_a predicted_func_score_b
158158
mutation
@@ -202,8 +202,8 @@ class Model:
202202
Next, we fit the model with some chosen hyperparameters.
203203
204204
>>> model.fit(maxiter=10, lasso_shift=1e-5, warn_unconverged=False)
205-
>>> model.loss
206-
0.3483478119356665
205+
>>> model.loss # doctest: +ELLIPSIS
206+
0.348347811935666...
207207
208208
The model tunes its parameters in place, and the subsequent call to retrieve
209209
the loss reflects our models loss given its updated parameters.

multidms/model_collection.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737

3838

3939
PARAMETER_NAMES_FOR_PLOTTING = {
40-
"coef_lasso_shift": "Lasso Penalty",
40+
"scale_coeff_lasso_shift": "Lasso Penalty",
4141
}
4242

4343

@@ -350,7 +350,7 @@ def all_mutations(self) -> tuple:
350350
@lru_cache(maxsize=10)
351351
def split_apply_combine_muts(
352352
self,
353-
groupby=("dataset_name", "coef_lasso_shift"),
353+
groupby=("dataset_name", "scale_coeff_lasso_shift"),
354354
aggregate_func="mean",
355355
inner_merge_dataset_muts=True,
356356
query=None,
@@ -374,7 +374,7 @@ def split_apply_combine_muts(
374374
groupby : str or tuple of str or None, optional
375375
The attributes to group the fits by. If None, then group by all
376376
attributes except for the model, data, and step_loss attributes.
377-
The default is ("dataset_name", "coef_lasso_shift").
377+
The default is ("dataset_name", "scale_coeff_lasso_shift").
378378
aggregate_func : str or callable, optional
379379
The function to aggregate the mutational dataframes within each group.
380380
The default is "mean".
@@ -524,7 +524,7 @@ def add_validation_loss(self, test_data, overwrite=False):
524524
def get_conditional_loss_df(self, query=None):
525525
"""
526526
Return a long form dataframe with columns
527-
"dataset_name", "coef_lasso_shift",
527+
"dataset_name", "scale_coeff_lasso_shift",
528528
"split" ("training" or "validation"),
529529
"loss" (actual value), and "condition".
530530
@@ -541,7 +541,7 @@ def get_conditional_loss_df(self, query=None):
541541
if len(queried_fits) == 0:
542542
raise ValueError("invalid query, no fits returned")
543543

544-
id_vars = ["dataset_name", "coef_lasso_shift"]
544+
id_vars = ["dataset_name", "scale_coeff_lasso_shift"]
545545
value_vars = [
546546
c for c in queried_fits.columns if "loss" in c and c != "step_loss"
547547
]
@@ -559,7 +559,7 @@ def get_conditional_loss_df(self, query=None):
559559
def convergence_trajectory_df(
560560
self,
561561
query=None,
562-
id_vars=("dataset_name", "coef_lasso_shift"),
562+
id_vars=("dataset_name", "scale_coeff_lasso_shift"),
563563
):
564564
"""
565565
Combine the converence trajectory dataframes of
@@ -776,7 +776,7 @@ def mut_param_traceplot(
776776
self,
777777
mutations,
778778
mut_param="shift",
779-
x="coef_lasso_shift",
779+
x="scale_coeff_lasso_shift",
780780
width_scalar=100,
781781
height_scalar=100,
782782
**kwargs,
@@ -821,7 +821,7 @@ def mut_param_traceplot(
821821
muts_df = muts_df.query("mutation.isin(@mutations)")
822822

823823
# check that we have multiple lasso penalty weights
824-
if len(muts_df.coef_lasso_shift.unique()) <= 1:
824+
if len(muts_df.scale_coeff_lasso_shift.unique()) <= 1:
825825
raise ValueError(
826826
"invalid kwargs, must specify a subset of fits with "
827827
"multiple lasso penalty weights"
@@ -834,7 +834,7 @@ def mut_type(mut):
834834
muts_df = muts_df.assign(mut_type=muts_df.mutation.apply(mut_type))
835835

836836
# melt conditions and stats cols, beta is already "tall"
837-
# id_cols = ["coef_lasso_shift", "mutation", "is_stop"]
837+
# id_cols = ["scale_coeff_lasso_shift", "mutation", "is_stop"]
838838
id_cols = ["dataset_name", x, "mut_type", "mutation"]
839839
stat_cols_to_keep = [c for c in muts_df.columns if c.startswith(mut_param)]
840840
if mut_param == "beta":
@@ -891,7 +891,7 @@ def mut_type(mut):
891891

892892
def shift_sparsity(
893893
self,
894-
x="coef_lasso_shift",
894+
x="scale_coeff_lasso_shift",
895895
width_scalar=100,
896896
height_scalar=100,
897897
return_data=False,
@@ -952,7 +952,7 @@ def mut_type(mut):
952952
alt.Chart(sparsity_df)
953953
.encode(
954954
x=alt.X(
955-
"coef_lasso_shift",
955+
"scale_coeff_lasso_shift",
956956
type="nominal",
957957
title=(
958958
PARAMETER_NAMES_FOR_PLOTTING[x]
@@ -967,7 +967,7 @@ def mut_type(mut):
967967
),
968968
color=alt.Color("mut_type", type="nominal", title="Mutation type"),
969969
tooltip=[
970-
"coef_lasso_shift",
970+
"scale_coeff_lasso_shift",
971971
"sparsity",
972972
"mut_type",
973973
],
@@ -995,7 +995,7 @@ def mut_type(mut):
995995

996996
def mut_param_dataset_correlation(
997997
self,
998-
x="coef_lasso_shift",
998+
x="scale_coeff_lasso_shift",
999999
width_scalar=200,
10001000
height=200,
10011001
return_data=False,
@@ -1012,7 +1012,7 @@ def mut_param_dataset_correlation(
10121012
----------
10131013
x : str, optional
10141014
The parameter to plot on the x-axis.
1015-
The default is "coef_lasso_shift".
1015+
The default is "scale_coeff_lasso_shift".
10161016
width_scalar : int, optional
10171017
The width of the chart. The default is 150.
10181018
height : int, optional

0 commit comments

Comments
 (0)