Skip to content

Commit 0f2e8da

Browse files
Authored by: (author name lost in page extraction)
enable per output token likelihood prediction for evo2 (#1057)
### Description <!-- Provide a detailed description of the changes in this PR --> ### Type of changes <!-- Mark the relevant option with an [x] --> - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Refactor - [ ] Documentation update - [ ] Other (please describe): ### CI Pipeline Configuration Configure CI behavior by applying the relevant labels: - [SKIP_CI](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#skip_ci) - Skip all continuous integration tests - [INCLUDE_NOTEBOOKS_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_notebooks_tests) - Execute notebook validation tests in pytest - [INCLUDE_SLOW_TESTS](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/contributing/contributing.md#include_slow_tests) - Execute tests labelled as slow in pytest for extensive testing > [!NOTE] > By default, the notebooks validation tests are skipped unless explicitly enabled. #### Authorizing CI Runs We use [copy-pr-bot](https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/#automation) to manage authorization of CI runs on NVIDIA's compute resources. - If a pull request is opened by a trusted user and contains only trusted changes, the pull request's code will automatically be copied to a pull-request/ prefixed branch in the source repository (e.g. pull-request/123) - If a pull request is opened by an untrusted user or contains untrusted changes, an NVIDIA org member must leave an `/ok to test` comment on the pull request to trigger CI. This will need to be done for each new commit. 
### Usage <!--- How does a user interact with the changed code --> ```python # TODO: Add code snippet ``` ### Pre-submit Checklist <!--- Ensure all items are completed before submitting --> - [ ] I have tested these changes locally - [ ] I have updated the documentation accordingly - [ ] I have added/updated tests as needed - [ ] All existing tests pass successfully Signed-off-by: Yang Zhang <yangzhang@nvidia.com>
1 parent 340cd2c commit 0f2e8da

File tree

1 file changed

+11
-7
lines changed
  • sub-packages/bionemo-evo2/src/bionemo/evo2/run

1 file changed

+11
-7
lines changed

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def parse_args():
111111
)
112112
ap.add_argument(
113113
"--log-prob-collapse-option",
114-
choices=["sum", "mean"],
114+
choices=["sum", "mean", "per_token"],
115115
default="mean",
116116
help="How to collapse the log probabilities across the sequence dimension.",
117117
)
@@ -160,7 +160,7 @@ def __init__(
160160
self,
161161
*args,
162162
output_log_prob_seqs: bool = False,
163-
log_prob_collapse_option: Literal["sum", "mean"] = "mean",
163+
log_prob_collapse_option: Literal["sum", "mean", "per_token"] = "mean",
164164
**kwargs,
165165
):
166166
"""Initialize the predictor with our needs around computing log probabilities."""
@@ -195,10 +195,14 @@ def predict_step(self, batch, batch_idx: int | None = None) -> Tensor:
195195
2, # along the vocab dimension...
196196
input_ids.unsqueeze(-1), # using the token ids to index.
197197
).squeeze(-1)
198-
log_prob_seqs = torch.sum(logprobs * batch["loss_mask"][:, 1:].float(), dim=-1)
199-
if self.log_prob_collapse_option == "mean":
200-
log_prob_seqs = log_prob_seqs / (batch["loss_mask"][:, 1:].float().sum(dim=-1) + 1e-8)
201-
return {"log_probs_seqs": log_prob_seqs.cpu(), "seq_idx": batch["seq_idx"].cpu()}
198+
log_prob_per_token = logprobs * batch["loss_mask"][:, 1:].float()
199+
if self.log_prob_collapse_option == "per_token":
200+
return {"log_probs_seqs": log_prob_per_token.cpu(), "seq_idx": batch["seq_idx"].cpu()}
201+
else:
202+
log_prob_seqs = torch.sum(log_prob_per_token, dim=1)
203+
if self.log_prob_collapse_option == "mean":
204+
log_prob_seqs = log_prob_seqs / (batch["loss_mask"][:, 1:].float().sum(dim=-1) + 1e-8)
205+
return {"log_probs_seqs": log_prob_seqs.cpu(), "seq_idx": batch["seq_idx"].cpu()}
202206
else:
203207
# If the user wants to match back to logits, then they will need to do the offsetting logic themselves.
204208
return {
@@ -504,7 +508,7 @@ def __init__(
504508
config,
505509
tokenizer=None,
506510
output_log_prob_seqs: bool = False,
507-
log_prob_collapse_option: Literal["sum", "mean"] = "mean",
511+
log_prob_collapse_option: Literal["sum", "mean", "per_token"] = "mean",
508512
):
509513
"""Initialize the MambaPredictor, which wraps the mamba model for prediction handling model parallelism.
510514

0 commit comments

Comments (0)