Skip to content

Commit 9b3ef20

Browse files
authored
V1.0.1 release (#69)
* fix: Resolve "AttributeError: 'SpectrumDataFrame' object has no attribute 'df'"
* feat: update notebooks to v1.0.0
* feat: Automatic model download and improve residues
  Co-Authored-By: Kevin Eloff <k.eloff@instadeep.com>
* feat: update tests for v1.0.1 release
  Co-Authored-By: Rachel Catzel <r.catzel@instadeep.com>
* feat: update packages
1 parent dca4423 commit 9b3ef20

29 files changed

+1070
-448
lines changed

.github/workflows/python-publish.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -34,12 +34,12 @@ jobs:
3434
- name: Build package
3535
run: python -m build
3636
- name: Publish package
37-
uses: pypa/gh-action-pypi-publish@fb13cb306901256ace3dab689990e13a5550ffaa
37+
uses: pypa/gh-action-pypi-publish@v1.12
3838
with:
3939
user: __token__
4040
password: ${{ secrets.PYPI_API_TOKEN }}
4141
- name: refresh PyPI badge
42-
uses: fjogeleit/http-request-action@v1
42+
uses: fjogeleit/http-request-action@v1.16
4343
with:
4444
url: https://camo.githubusercontent.com/a22fbcbadf81751212d5367cce341631bc28d7749b9cd5c317fbf0706a30c9ae/68747470733a2f2f62616467652e667572792e696f2f70792f696e7374616e6f766f2e737667
4545
method: PURGE

instanovo/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,3 @@
11
from __future__ import annotations
22

3-
__version__ = "1.0.0"
3+
__version__ = "1.0.1"

instanovo/configs/inference/default.yaml

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
# Data paths and output location
2-
data_path: # type: .mgf, .mzml or any other filetype supported by SpectruMataFrame
3-
model_path: # type: .ckpt
2+
data_path: # type: .mgf, .mzml or any other filetype supported by SpectrumDataFrame
3+
model_path: instanovo-extended # type: .ckpt or model id
44
output_path: # type: .csv
55
knapsack_path: # type: directory
66

@@ -17,9 +17,11 @@ use_knapsack: False
1717
save_beams: False
1818
subset: 1.0 # Subset of dataset to perform inference on, useful for debugging
1919

20+
# These two only work in greedy search
2021
# Residues whose log probability will be set to -inf
21-
# Only works in greedy search
22-
# suppressed_residues: TODO
22+
suppressed_residues:
23+
# Stop model from predicting n-terminal modifications anywhere along the sequence
24+
disable_terminal_residues_anywhere: True
2325

2426
# Run config
2527
num_workers: 16

instanovo/configs/inference/unit_test.yaml

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -4,13 +4,14 @@ defaults:
44

55
# Data paths and output location
66
data_path: ./tests/instanovo_test_resources/example_data/test_sample.mgf # type: .ipc
7-
model_path: ./tests/instanovo_test_resources/train_test/epoch=4-step=2420.ckpt # type: .ckpt
8-
output_path: ./tests/instanovo_test_resources/train_test/test_sample_preds.csv # type: .csv
7+
model_path: ./tests/instanovo_test_resources/model.ckpt # type: .ckpt
8+
output_path: ./tests/instanovo_test_resources/test_sample_preds.csv # type: .csv
99
knapsack_path: ./tests/instanovo_test_resources/example_knapsack # type: directory
1010
use_knapsack: False
1111

1212
num_beams: 5
1313
max_length: 30
14+
max_charge: 3
1415

1516
subset: 1
1617

instanovo/configs/instanovo.yaml

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -36,8 +36,11 @@ train_subset: 1.0
3636
valid_subset: 0.01
3737
val_check_interval: 1.0 # 1.0 This doesn't work
3838
lazy_loading: True # Use lazy loading mode
39-
max_shard_size: 100_000 # Max data shard size for lazy loading, may influence shuffling mechanics
39+
max_shard_size: 1_000_000 # Max data shard size for lazy loading, may influence shuffling mechanics
40+
preshuffle_shards: True # Perform a preshuffle across shards to ensure shards are homogeneous in lazy mode
4041
perform_data_checks: True # Check residues, check precursor masses, etc.
42+
validate_precursor_mass: False # Slow for large datasets
43+
verbose_loading: True # Verbose SDF logs when loading the dataset
4144

4245
# Checkpointing parameters
4346
save_model: True

instanovo/configs/instanovo_unit_test.yaml

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -9,15 +9,15 @@ defaults:
99
tb_summarywriter: "./logs/instanovo/instanovo-unit-test"
1010

1111
# Training parameters
12-
warmup_iters: 1000
12+
warmup_iters: 480
1313
max_iters: 3_000_000
1414
learning_rate: 1e-3
1515
train_batch_size: 32
1616
grad_accumulation: 1
1717

1818
# Logging parameters
1919
logger:
20-
epochs: 5
20+
epochs: 1
2121
num_sanity_val_steps: 10
2222
console_logging_steps: 50
2323
tensorboard_logging_steps: 500
@@ -29,4 +29,4 @@ valid_subset: 1.0
2929

3030
# Checkpointing parameters
3131
model_save_folder_path: ./tests/instanovo_test_resources/train_test
32-
ckpt_interval: 2420
32+
ckpt_interval: 480

instanovo/inference/greedy_search.py

Lines changed: 77 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,9 +21,18 @@ class GreedyDecoder(Decoder):
2121
models that conform to the `Decodable` interface.
2222
"""
2323

24-
def __init__(self, model: Decodable, mass_scale: int = MASS_SCALE):
24+
def __init__(
25+
self,
26+
model: Decodable,
27+
suppressed_residues: list[str] | None = None,
28+
mass_scale: int = MASS_SCALE,
29+
disable_terminal_residues_anywhere: bool = True,
30+
):
2531
super().__init__(model=model)
2632
self.mass_scale = mass_scale
33+
self.disable_terminal_residues_anywhere = disable_terminal_residues_anywhere
34+
35+
suppressed_residues = suppressed_residues or []
2736

2837
# NOTE: Greedy search requires `residue_set` class in the model, update all methods accordingly.
2938
if not hasattr(model, "residue_set"):
@@ -37,10 +46,32 @@ def __init__(self, model: Decodable, mass_scale: int = MASS_SCALE):
3746
self.residue_masses = torch.zeros(
3847
(len(self.model.residue_set),), dtype=torch.float64
3948
)
49+
terminal_residues_idx: list[int] = []
50+
suppressed_residues_idx: list[int] = []
4051
for i, residue in enumerate(model.residue_set.vocab):
4152
if residue in self.model.residue_set.special_tokens:
4253
continue
4354
self.residue_masses[i] = self.model.residue_set.get_mass(residue)
55+
# If no residue is attached, assume it is a n-terminal residue
56+
if not residue[0].isalpha():
57+
terminal_residues_idx.append(i)
58+
59+
# Check if residue is suppressed
60+
if residue in suppressed_residues:
61+
suppressed_residues_idx.append(i)
62+
suppressed_residues.remove(residue)
63+
64+
if len(suppressed_residues) > 0:
65+
raise ValueError(
66+
f"Suppressed residues not found in vocabulary: {suppressed_residues}"
67+
)
68+
69+
self.terminal_residue_indices = torch.tensor(
70+
terminal_residues_idx, dtype=torch.long
71+
)
72+
self.suppressed_residue_indices = torch.tensor(
73+
suppressed_residues_idx, dtype=torch.long
74+
)
4475

4576
self.vocab_size = len(self.model.residue_set)
4677

@@ -270,10 +301,53 @@ def decode( # type:ignore
270301
next_token_probabilities_filtered[
271302
:, self.model.residue_set.EOS_INDEX
272303
] = -float("inf")
304+
# Allow the model to predict PAD when all residues are -inf
305+
# next_token_probabilities_filtered[
306+
# :, self.model.residue_set.PAD_INDEX
307+
# ] = -float("inf")
273308
next_token_probabilities_filtered[
274309
:, self.model.residue_set.SOS_INDEX
275310
] = -float("inf")
276-
# TODO set probability of n-terminal modifications to 0 when i > 0, requires n-terms to be specified in residue_set
311+
next_token_probabilities_filtered[
312+
:, self.suppressed_residue_indices
313+
] = -float("inf")
314+
# Set probability of n-terminal modifications to -inf when i > 0
315+
if self.disable_terminal_residues_anywhere:
316+
# Check if adding terminal residues would result in a complete sequence
317+
# First generate remaining mass matrix with isotopes
318+
remaining_mass_incomplete_isotope = remaining_mass_incomplete[
319+
:, None
320+
].expand(sub_batch_size, max_isotope + 1) - CARBON_MASS_DELTA * (
321+
torch.arange(max_isotope + 1, device=device)
322+
)
323+
# Expand with terminal residues and subtract
324+
remaining_mass_incomplete_isotope_delta = (
325+
remaining_mass_incomplete_isotope[:, :, None].expand(
326+
sub_batch_size,
327+
max_isotope + 1,
328+
self.terminal_residue_indices.shape[0],
329+
)
330+
- self.residue_masses[self.terminal_residue_indices]
331+
)
332+
333+
# If within target delta, allow these residues to be predicted, otherwise set probability to -inf
334+
allow_terminal = (
335+
remaining_mass_incomplete_isotope_delta.abs()
336+
< mass_target_incomplete[:, None, None]
337+
).any(dim=1)
338+
allow_terminal_full = torch.ones(
339+
(sub_batch_size, self.vocab_size),
340+
device=spectra.device,
341+
dtype=bool,
342+
)
343+
allow_terminal_full[:, self.terminal_residue_indices] = (
344+
allow_terminal
345+
)
346+
347+
# Set to -inf
348+
next_token_probabilities_filtered[~allow_terminal_full] = -float(
349+
"inf"
350+
)
277351

278352
# Step 5: Select next token:
279353
next_token = next_token_probabilities_filtered.argmax(-1).unsqueeze(
@@ -362,7 +436,7 @@ def decode( # type:ignore
362436
token_log_probabilities=[
363437
x.cpu().item()
364438
for x in all_log_probabilities[i, : len(sequence)]
365-
], # list[float] (sequence_length) excludes EOS
439+
][::-1], # list[float] (sequence_length) excludes EOS
366440
)
367441
)
368442

instanovo/models.json

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,8 @@
1+
{
2+
"transformer": {
3+
"instanovo-extended": {
4+
"url": "https://github.com/instadeepai/InstaNovo/releases/download/1.0.0/instanovo_extended.ckpt"
5+
}
6+
},
7+
"diffusion": {}
8+
}

instanovo/scripts/convert_to_sdf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111

1212
def main() -> None:
13-
"""Convert data to ipc."""
13+
"""Convert data to spectrum data frame and save as parquet."""
1414
logging.basicConfig(level=logging.INFO)
1515
parser = argparse.ArgumentParser()
1616

instanovo/scripts/get_zenodo_record.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -68,15 +68,13 @@ def main(
6868
extract_path + "/instanovo_test_resources"
6969
):
7070
print(
71-
f"Skipping download and extraction. Path '{extract_path}'/instanovo_test_resources already exists and is non-empty."
71+
f"Skipping download and extraction. Path '{extract_path}/instanovo_test_resources' already exists and is non-empty."
7272
)
7373
return
7474

7575
get_zenodo(zenodo_url, zip_path)
7676
unzip_zenodo(zip_path, extract_path)
7777

78-
os.makedirs("./tests/instanovo_test_resources/train_test", exist_ok=True)
79-
8078

8179
if __name__ == "__main__":
8280
main()

0 commit comments

Comments (0)