optimize generative_reranker memory#115
Conversation
There was a problem hiding this comment.
Code Review
This pull request optimizes the generative reranker task by introducing a new _forward_generative_reranker method, which computes logits only for the positive and negative target tokens instead of the entire vocabulary. Feedback on these changes identifies a critical correctness bug when Sequence Parallelism is enabled, as sequence slices would be incorrectly summed during the all-reduce step. Additionally, the feedback points out that the current implementation ignores the output layer's bias and lacks defensive checks for the tokenizer and token IDs, and provides a detailed code suggestion to resolve these issues.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request optimizes the generative reranker task by only calculating logits for the positive and negative tokens instead of the entire vocabulary. This is implemented via the new _init_reranker_cache and _forward_generative_reranker methods in GPTModel. Feedback highlights a critical issue where the newly initialized cache tensors are created on the CPU by default, which will cause device mismatch errors during the forward pass. Additionally, the feedback suggests matching the mask's data type to the weight's data type to avoid type promotion issues.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request optimizes the generative reranker task in gpt_model.py by computing logits only for the positive and negative tokens instead of the entire vocabulary, supporting both Tensor and Sequence Parallelism. The feedback suggests adding defensive checks to ensure the tokenizer is initialized and token IDs are successfully retrieved, as well as casting the boolean mask to the weight's data type to avoid type promotion warnings and performance issues.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def _init_reranker_cache(self, weight): | ||
| """One-time initialization of generative reranker constants.""" | ||
| positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes') | ||
| negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no') | ||
| positive_token_id = self.tokenizer.convert_tokens_to_ids(positive_token) | ||
| negative_token_id = self.tokenizer.convert_tokens_to_ids(negative_token) | ||
|
|
||
| token_ids = torch.tensor([positive_token_id, negative_token_id], dtype=torch.long, device=weight.device) |
There was a problem hiding this comment.
To ensure robust defensive programming, we should verify that self.tokenizer is initialized on the model before attempting to access it. Additionally, we should check if the token IDs for the positive and negative tokens are successfully retrieved (i.e., not None) to prevent a cryptic TypeError when creating the PyTorch tensor.
def _init_reranker_cache(self, weight):
"""One-time initialization of generative reranker constants."""
if getattr(self, 'tokenizer', None) is None:
raise ValueError("Tokenizer is not initialized on the model, which is required for the generative reranker.")
positive_token = os.environ.get('GENERATIVE_RERANKER_POSITIVE_TOKEN', 'yes')
negative_token = os.environ.get('GENERATIVE_RERANKER_NEGATIVE_TOKEN', 'no')
positive_token_id = self.tokenizer.convert_tokens_to_ids(positive_token)
negative_token_id = self.tokenizer.convert_tokens_to_ids(negative_token)
if positive_token_id is None or negative_token_id is None:
raise ValueError(
f"Could not find token IDs for positive token '{positive_token}' or negative token '{negative_token}' in the tokenizer vocabulary."
)
token_ids = torch.tensor([positive_token_id, negative_token_id], dtype=torch.long, device=weight.device)
No description provided.