Skip to content

Commit 0cd2d35

Browse files
tomwardiocopybara-github
authored andcommitted
Add support for center mask scoring over entire sequence
PiperOrigin-RevId: 796402809 Change-Id: Ic45642fd777f776a36a75a0d0ccc50542cf2edd1
1 parent 3ca50c7 commit 0cd2d35

3 files changed

Lines changed: 20 additions & 5 deletions

File tree

src/alphagenome/models/variant_scorers.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ class BaseVariantScorer(enum.Enum):
138138
})
139139

140140
SUPPORTED_WIDTHS = immutabledict.immutabledict({
141-
BaseVariantScorer.CENTER_MASK: [501, 2001, 10_001, 100_001, 200_001],
142-
BaseVariantScorer.GENE_MASK_SPLICING: [101, 1_001, 10_001, None],
141+
BaseVariantScorer.CENTER_MASK: [None, 501, 2001, 10_001, 100_001, 200_001],
142+
BaseVariantScorer.GENE_MASK_SPLICING: [None, 101, 1_001, 10_001],
143143
})
144144

145145
SUPPORTED_AGGREGATIONS = immutabledict.immutabledict({
@@ -166,7 +166,8 @@ class CenterMaskScorer:
166166
167167
Attributes:
168168
requested_output: The requested output type (e.g., ATAC, DNASE, etc.)
169-
width: The width of the mask around the variant.
169+
width: The width of the mask around the variant. If None, the score is
170+
computed over the entire sequence.
170171
aggregation_type: The aggregation type.
171172
base_variant_scorer: The base variant scorer.
172173
name: The name of the scorer (a composite of the above attributes that
@@ -183,7 +184,7 @@ class CenterMaskScorer:
183184
"""
184185

185186
requested_output: dna_output.OutputType
186-
width: int
187+
width: int | None
187188
aggregation_type: AggregationType
188189

189190
@property

src/alphagenome/models/variant_scorers_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,15 @@ def test_variant_scorers_to_proto(self):
6464
requested_output=dna_model_pb2.OUTPUT_TYPE_PROCAP,
6565
)
6666
),
67+
dna_model_pb2.VariantScorer(
68+
center_mask=dna_model_pb2.CenterMaskScorer(
69+
width=None,
70+
aggregation_type=(
71+
dna_model_pb2.AggregationType.AGGREGATION_TYPE_DIFF_LOG2_SUM
72+
),
73+
requested_output=dna_model_pb2.OUTPUT_TYPE_ATAC,
74+
)
75+
),
6776
dna_model_pb2.VariantScorer(
6877
gene_mask=dna_model_pb2.GeneMaskLFCScorer(
6978
requested_output=dna_model_pb2.OUTPUT_TYPE_RNA_SEQ,
@@ -93,6 +102,11 @@ def test_variant_scorers_to_proto(self):
93102
aggregation_type=(variant_scorers.AggregationType.ACTIVE_MEAN),
94103
requested_output=dna_output.OutputType.PROCAP,
95104
),
105+
variant_scorers.CenterMaskScorer(
106+
width=None,
107+
aggregation_type=(variant_scorers.AggregationType.DIFF_LOG2_SUM),
108+
requested_output=dna_output.OutputType.ATAC,
109+
),
96110
variant_scorers.GeneMaskLFCScorer(
97111
requested_output=dna_output.OutputType.RNA_SEQ,
98112
),

src/alphagenome/protos/dna_model.proto

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ message IntervalScorer {
295295
// Variant scorer message for center mask scoring.
296296
message CenterMaskScorer {
297297
// The width of the mask around the variant.
298-
int64 width = 1;
298+
optional int64 width = 1;
299299

300300
// The aggregation type.
301301
AggregationType aggregation_type = 2;

0 commit comments

Comments
 (0)