Skip to content

Conversation

@jgallowa07
Copy link
Member

Summary

Implements masking functionality for ParentIndependentBinarySelectionModel to properly handle masked positions during training and inference.

Fixes #165

Changes Made

Core Implementation

  • Modified ParentIndependentBinarySelectionModel.forward(): Added proper mask application using multiplicative masking in log space
  • Consistent with existing models: Follows the same masking pattern used by FivemerModel, SHMoofModel, and other models in the codebase
  • Handles both output dimensions: Works correctly for both 1D (output_dim=1) and multi-dimensional (output_dim>=20) outputs
  • Preserves wildtype zapping: Masking is applied after wildtype sequence zapping for multi-dimensional outputs

Key Features

  • Multiplicative masking: Masked positions (where mask=False) are multiplied by 0 in log space
  • Selection factor behavior: After exponentiation, masked positions have selection factor of 1 (neutral)
  • Gradient handling: Masked positions naturally don't contribute to gradients
  • Backward compatibility: All existing functionality preserved when mask is all True

Testing

  • Comprehensive test suite: Added test_parent_independent_mask.py with 10 test cases
  • Coverage includes:
    • 1D and multi-dimensional output masking
    • Interaction with wildtype zapping
    • Edge cases (all masked, all unmasked)
    • Gradient flow verification
    • Batch consistency and different masks per sequence
    • Device compatibility and dtype handling
  • All tests pass: Both new tests and existing model factory tests

Documentation

  • Updated docstring: Added clear documentation of mask parameter behavior
  • Inline comments: Explained masking implementation and consistency rationale

Test Results

# New tests
$ pixi run pytest tests/test_parent_independent_mask.py -v
============================== 10 passed in 0.04s ==============================

# Existing compatibility tests
$ pixi run pytest tests/test_model_factory.py -v
======================== 24 passed, 1 skipped in 0.04s =========================

Technical Details

Before (mask parameter ignored)

def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
    batch_size, seq_len = amino_acid_indices.shape
    position_factors = self.log_selection_factors[:seq_len]
    result = position_factors.expand(batch_size, seq_len)
    # mask parameter completely ignored!
    return result

After (proper mask application)

def forward(self, amino_acid_indices: Tensor, mask: Tensor) -> Tensor:
    batch_size, seq_len = amino_acid_indices.shape
    position_factors = self.log_selection_factors[:seq_len]
    
    if self.output_dim == 1:
        result = position_factors.expand(batch_size, seq_len)
        result = result * mask  # Apply masking
    else:
        result = position_factors.unsqueeze(0).expand(batch_size, seq_len, self.output_dim).clone()
        # Apply wildtype zapping first if applicable
        if self.output_dim >= 20 and self.wildtype_aa_idxs is not None:
            wt_idxs_batch = self.wildtype_aa_idxs[:seq_len].unsqueeze(0).expand(batch_size, -1)
            result = zap_predictions_along_diagonal(result, wt_idxs_batch, fill=0.0)
        # Then apply masking
        result = result * mask.unsqueeze(-1)
    
    return result

Impact

  • Training: Loss calculations can now properly exclude masked positions
  • Inference: Model handles variable-length sequences and specific position exclusions
  • Consistency: Masking behavior now matches other models in the codebase
  • No breaking changes: Existing code continues to work unchanged

jgallowa07 and others added 2 commits September 15, 2025 12:28
Fixes #165

- Implement multiplicative masking in forward method following pattern used by other models
- Masked positions (mask=False) get multiplied by 0 in log space, resulting in selection factor of 1 after exp
- Add comprehensive unit tests covering 1D and multi-dimensional outputs, edge cases, and gradient flow
- Update docstring to document mask behavior
- Maintain backward compatibility with existing code

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
- Remove unused numpy import
- Fix E712 violations by using proper boolean operations (~mask instead of mask == False)
- All CI checks now pass: checkformat, lint, checktodo
- Tests continue to pass after formatting fixes

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
@jgallowa07 jgallowa07 requested a review from matsen September 15, 2025 19:43
@jgallowa07 jgallowa07 merged commit f523ffa into main Sep 15, 2025
2 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add masking support for ParentIndependentBinarySelectionModel to exclude specific sites from selection factor calculations

3 participants