Skip to content

Commit f523ffa

Browse files
jgallowa07claude
andauthored
Add masking support to ParentIndependentBinarySelectionModel (#166)
* Add masking support to ParentIndependentBinarySelectionModel Fixes #165 - Implement multiplicative masking in forward method following pattern used by other models - Masked positions (mask=False) get multiplied by 0 in log space, resulting in selection factor of 1 after exp - Add comprehensive unit tests covering 1D and multi-dimensional outputs, edge cases, and gradient flow - Update docstring to document mask behavior - Maintain backward compatibility with existing code 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> * Fix formatting and linting issues in test file - Remove unused numpy import - Fix E712 violations by using proper boolean operations (~mask instead of mask == False) - All CI checks now pass: checkformat, lint, checktodo - Tests continue to pass after formatting fixes 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]> --------- Co-authored-by: Claude <[email protected]>
1 parent 55cae8e commit f523ffa

File tree

2 files changed

+257
-0
lines changed

2 files changed

+257
-0
lines changed

netam/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,10 +1057,13 @@ def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
10571057
Since this is a parent-independent model, the amino_acid_indices parameter
10581058
is ignored - we only use it to determine batch size and sequence length.
10591059
The selection factors depend only on position, not on the parent sequence.
1060+
Masked positions (where mask is False) will have their log selection factors
1061+
set to 0, which becomes a selection factor of 1 after exponentiation.
10601062
10611063
Args:
10621064
amino_acid_indices: A tensor of shape (B, L) - used only for shape.
10631065
mask: A tensor of shape (B, L) representing the mask of valid amino acid sites.
1066+
True for valid positions, False for positions to be masked.
10641067
10651068
Returns:
10661069
A tensor of shape (B, L) or (B, L, output_dim) with log selection factors.
@@ -1073,6 +1076,8 @@ def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
10731076
# Expand to match the required batch and output shape
10741077
if self.output_dim == 1:
10751078
result = position_factors.expand(batch_size, seq_len)
1079+
# Apply masking: multiplicative masking in log space (consistent with other models)
1080+
result = result * mask
10761081
else:
10771082
# Create a proper copy instead of a view to avoid in-place operation issues
10781083
result = (
@@ -1091,6 +1096,9 @@ def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
10911096
# Set wildtype aa selection factors to 0 (which becomes 1 after exp)
10921097
result = zap_predictions_along_diagonal(result, wt_idxs_batch, fill=0.0)
10931098

1099+
# Apply masking: expand mask to match output dimensions
1100+
result = result * mask.unsqueeze(-1)
1101+
10941102
return result
10951103

10961104

Lines changed: 249 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
"""Tests for ParentIndependentBinarySelectionModel mask functionality."""
2+
3+
import pytest
4+
import torch
5+
6+
from netam.models import ParentIndependentBinarySelectionModel
7+
8+
9+
class TestParentIndependentBinarySelectionModelMask:
10+
"""Test mask functionality for ParentIndependentBinarySelectionModel."""
11+
12+
@pytest.fixture
13+
def model_1d(self):
14+
"""Create a ParentIndependentBinarySelectionModel with output_dim=1."""
15+
return ParentIndependentBinarySelectionModel(
16+
output_dim=1, known_token_count=21, model_type="test"
17+
)
18+
19+
@pytest.fixture
20+
def model_multidim(self):
21+
"""Create a ParentIndependentBinarySelectionModel with output_dim=20."""
22+
return ParentIndependentBinarySelectionModel(
23+
output_dim=20, known_token_count=21, model_type="test"
24+
)
25+
26+
@pytest.fixture
27+
def model_with_wildtype(self):
28+
"""Create a ParentIndependentBinarySelectionModel with wildtype sequence."""
29+
return ParentIndependentBinarySelectionModel(
30+
output_dim=20,
31+
wildtype_sequence="ACDEFG",
32+
known_token_count=21,
33+
model_type="test",
34+
)
35+
36+
@pytest.fixture
37+
def sample_inputs(self):
38+
"""Create sample inputs for testing."""
39+
batch_size, seq_len = 2, 6
40+
amino_acid_indices = torch.randint(0, 21, (batch_size, seq_len))
41+
# Create mask with some positions masked (False)
42+
mask = torch.tensor(
43+
[
44+
[True, True, False, True, False, True],
45+
[True, False, True, True, True, False],
46+
],
47+
dtype=torch.bool,
48+
)
49+
return amino_acid_indices, mask
50+
51+
def test_mask_application_1d(self, model_1d, sample_inputs):
52+
"""Test that masking works correctly for 1D output."""
53+
amino_acid_indices, mask = sample_inputs
54+
55+
# Get output with mask
56+
result = model_1d.forward(amino_acid_indices, mask)
57+
58+
# Check that result has correct shape
59+
assert result.shape == (2, 6)
60+
61+
# Check that masked positions have value 0 (mask=False means multiply by 0)
62+
assert torch.all(result[~mask] == 0.0)
63+
64+
# Check that unmasked positions retain their original values
65+
# Compare with result when mask is all True
66+
all_true_mask = torch.ones_like(mask, dtype=torch.bool)
67+
unmasked_result = model_1d.forward(amino_acid_indices, all_true_mask)
68+
69+
# Unmasked positions should match
70+
assert torch.allclose(result[mask], unmasked_result[mask])
71+
72+
def test_mask_application_multidim(self, model_multidim, sample_inputs):
73+
"""Test that masking works correctly for multi-dimensional output."""
74+
amino_acid_indices, mask = sample_inputs
75+
76+
# Get output with mask
77+
result = model_multidim.forward(amino_acid_indices, mask)
78+
79+
# Check that result has correct shape
80+
assert result.shape == (2, 6, 20)
81+
82+
# Check that masked positions have value 0 across all output dimensions
83+
masked_positions = ~mask
84+
assert torch.all(result[masked_positions] == 0.0)
85+
86+
# Check that unmasked positions retain their original values
87+
all_true_mask = torch.ones_like(mask, dtype=torch.bool)
88+
unmasked_result = model_multidim.forward(amino_acid_indices, all_true_mask)
89+
90+
# Unmasked positions should match
91+
unmasked_positions = mask
92+
assert torch.allclose(
93+
result[unmasked_positions], unmasked_result[unmasked_positions]
94+
)
95+
96+
def test_mask_with_wildtype_zapping(self, model_with_wildtype, sample_inputs):
97+
"""Test that masking works correctly with wildtype zapping."""
98+
amino_acid_indices, mask = sample_inputs
99+
100+
# Adjust input to match wildtype sequence length
101+
amino_acid_indices = amino_acid_indices[
102+
:, :6
103+
] # Model has wildtype sequence of length 6
104+
mask = mask[:, :6]
105+
106+
result = model_with_wildtype.forward(amino_acid_indices, mask)
107+
108+
# Check shape
109+
assert result.shape == (2, 6, 20)
110+
111+
# Check that masked positions are 0
112+
masked_positions = ~mask
113+
assert torch.all(result[masked_positions] == 0.0)
114+
115+
def test_all_masked(self, model_1d):
116+
"""Test behavior when all positions are masked."""
117+
batch_size, seq_len = 2, 6
118+
amino_acid_indices = torch.randint(0, 21, (batch_size, seq_len))
119+
mask = torch.zeros((batch_size, seq_len), dtype=torch.bool) # All False
120+
121+
result = model_1d.forward(amino_acid_indices, mask)
122+
123+
# All positions should be 0
124+
assert torch.all(result == 0.0)
125+
126+
def test_all_unmasked(self, model_1d):
127+
"""Test behavior when no positions are masked."""
128+
batch_size, seq_len = 2, 6
129+
amino_acid_indices = torch.randint(0, 21, (batch_size, seq_len))
130+
mask = torch.ones((batch_size, seq_len), dtype=torch.bool) # All True
131+
132+
result = model_1d.forward(amino_acid_indices, mask)
133+
134+
# Should be same as model's learned parameters
135+
expected = model_1d.log_selection_factors[:seq_len].expand(batch_size, seq_len)
136+
assert torch.allclose(result, expected)
137+
138+
def test_gradient_flow_masked_positions(self, model_1d, sample_inputs):
139+
"""Test that gradients don't flow through masked positions."""
140+
amino_acid_indices, mask = sample_inputs
141+
142+
# Enable gradients
143+
model_1d.train()
144+
145+
# Forward pass
146+
result = model_1d.forward(amino_acid_indices, mask)
147+
148+
# Create a simple loss that only depends on the result
149+
loss = result.sum()
150+
151+
# Backward pass
152+
loss.backward()
153+
154+
# Check that gradients exist for the model parameters
155+
assert model_1d.log_selection_factors.grad is not None
156+
157+
# The gradient contribution from masked positions should be 0
158+
# This is automatically handled by the multiplication by 0
159+
160+
def test_mask_consistency_across_batches(self, model_1d):
161+
"""Test that masking is applied consistently across batch dimensions."""
162+
batch_size, seq_len = 3, 5
163+
amino_acid_indices = torch.randint(0, 21, (batch_size, seq_len))
164+
165+
# Create mask where same positions are masked across all batches
166+
mask = torch.tensor(
167+
[
168+
[True, False, True, False, True],
169+
[True, False, True, False, True],
170+
[True, False, True, False, True],
171+
],
172+
dtype=torch.bool,
173+
)
174+
175+
result = model_1d.forward(amino_acid_indices, mask)
176+
177+
# Masked positions (index 1 and 3) should be 0 for all batches
178+
assert torch.all(result[:, 1] == 0.0)
179+
assert torch.all(result[:, 3] == 0.0)
180+
181+
# Unmasked positions should have same values across batches (since they come from position-specific parameters)
182+
for pos in [0, 2, 4]:
183+
# All batches should have same value at this position (from position-specific params)
184+
assert torch.allclose(result[0, pos], result[1, pos])
185+
assert torch.allclose(result[1, pos], result[2, pos])
186+
187+
def test_different_masks_per_batch(self, model_1d):
188+
"""Test that different masks can be applied to different sequences in the
189+
batch."""
190+
batch_size, seq_len = 2, 4
191+
amino_acid_indices = torch.randint(0, 21, (batch_size, seq_len))
192+
193+
# Different masks for each sequence in batch
194+
mask = torch.tensor(
195+
[
196+
[True, False, True, True], # Second position masked
197+
[True, True, False, True], # Third position masked
198+
],
199+
dtype=torch.bool,
200+
)
201+
202+
result = model_1d.forward(amino_acid_indices, mask)
203+
204+
# Check that different positions are masked for each sequence
205+
assert result[0, 1] == 0.0 # Second position of first sequence
206+
assert result[1, 2] == 0.0 # Third position of second sequence
207+
208+
# Check that non-masked positions are non-zero (assuming learned parameters are non-zero)
209+
# We can't assume they're non-zero since they're initialized to zero, but they should be equal to the learned parameters
210+
expected_vals = model_1d.log_selection_factors[:seq_len]
211+
assert result[0, 0] == expected_vals[0]
212+
assert result[1, 1] == expected_vals[1]
213+
214+
def test_mask_device_compatibility(self, model_1d):
215+
"""Test that masking works correctly when tensors are on different devices."""
216+
batch_size, seq_len = 2, 4
217+
amino_acid_indices = torch.randint(0, 21, (batch_size, seq_len))
218+
mask = torch.ones((batch_size, seq_len), dtype=torch.bool)
219+
220+
# Move model to CPU (it should already be there, but make sure)
221+
model_1d.to("cpu")
222+
amino_acid_indices = amino_acid_indices.to("cpu")
223+
mask = mask.to("cpu")
224+
225+
# Should work without errors
226+
result = model_1d.forward(amino_acid_indices, mask)
227+
assert result.device.type == "cpu"
228+
229+
def test_mask_dtype_handling(self, model_1d):
230+
"""Test that different mask dtypes are handled correctly."""
231+
batch_size, seq_len = 2, 4
232+
amino_acid_indices = torch.randint(0, 21, (batch_size, seq_len))
233+
234+
# Test with float mask (0.0 and 1.0)
235+
float_mask = torch.tensor(
236+
[[1.0, 0.0, 1.0, 1.0], [1.0, 1.0, 0.0, 1.0]], dtype=torch.float32
237+
)
238+
239+
result_float = model_1d.forward(amino_acid_indices, float_mask)
240+
241+
# Test with bool mask
242+
bool_mask = torch.tensor(
243+
[[True, False, True, True], [True, True, False, True]], dtype=torch.bool
244+
)
245+
246+
result_bool = model_1d.forward(amino_acid_indices, bool_mask)
247+
248+
# Results should be the same
249+
assert torch.allclose(result_float, result_bool)

0 commit comments

Comments
 (0)