Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def parse_args():
)
ap.add_argument(
"--log-prob-collapse-option",
choices=["sum", "mean"],
choices=["sum", "mean", "per_token"],
default="mean",
help="How to collapse the log probabilities across the sequence dimension.",
)
Expand Down Expand Up @@ -160,7 +160,7 @@ def __init__(
self,
*args,
output_log_prob_seqs: bool = False,
log_prob_collapse_option: Literal["sum", "mean"] = "mean",
log_prob_collapse_option: Literal["sum", "mean", "per_token"] = "mean",
**kwargs,
):
"""Initialize the predictor with our needs around computing log probabilities."""
Expand Down Expand Up @@ -195,10 +195,14 @@ def predict_step(self, batch, batch_idx: int | None = None) -> Tensor:
2, # along the vocab dimension...
input_ids.unsqueeze(-1), # using the token ids to index.
).squeeze(-1)
log_prob_seqs = torch.sum(logprobs * batch["loss_mask"][:, 1:].float(), dim=-1)
if self.log_prob_collapse_option == "mean":
log_prob_seqs = log_prob_seqs / (batch["loss_mask"][:, 1:].float().sum(dim=-1) + 1e-8)
return {"log_probs_seqs": log_prob_seqs.cpu(), "seq_idx": batch["seq_idx"].cpu()}
log_prob_per_token = logprobs * batch["loss_mask"][:, 1:].float()
if self.log_prob_collapse_option == "per_token":
return {"log_probs_seqs": log_prob_per_token.cpu(), "seq_idx": batch["seq_idx"].cpu()}
else:
log_prob_seqs = torch.sum(log_prob_per_token, dim=1)
if self.log_prob_collapse_option == "mean":
log_prob_seqs = log_prob_seqs / (batch["loss_mask"][:, 1:].float().sum(dim=-1) + 1e-8)
return {"log_probs_seqs": log_prob_seqs.cpu(), "seq_idx": batch["seq_idx"].cpu()}
else:
# If the user wants to match back to logits, then they will need to do the offsetting logic themselves.
return {
Expand Down Expand Up @@ -504,7 +508,7 @@ def __init__(
config,
tokenizer=None,
output_log_prob_seqs: bool = False,
log_prob_collapse_option: Literal["sum", "mean"] = "mean",
log_prob_collapse_option: Literal["sum", "mean", "per_token"] = "mean",
):
"""Initialize the MambaPredictor, which wraps the mamba model for prediction handling model parallelism.

Expand Down