[bridge] Support GPTBridge callback #8
Code Review
This pull request introduces a converter callback to the load_weights, export_weights, and save_weights methods, enabling custom key-value transformations during weight operations. The review feedback identifies potential runtime issues, including an AttributeError during loading if the converter returns a standard tensor instead of a LazyTensor, and a TypeError during export if the converter is used for filtering and returns None. A suggestion was also provided to refine the type hint for the converter argument to Optional[Callable] for better consistency.
    if converter:
        state_dict = dict(converter(k, v, adapter_name=adapter_name) for k, v in state_dict.items())
The converter callback receives a LazyTensor object as the value v during weight loading. If the converter returns a standard torch.Tensor (e.g., after applying some transformation), the subsequent code in _set_state_dict will fail with an AttributeError because it explicitly calls .load() on the state dict values (see line 472).
Additionally, there is a type inconsistency in the converter interface: in load_weights, v is a LazyTensor, while in export_weights, v is a torch.Tensor. This makes it difficult to write a single converter that works for both operations if it needs to access or modify the tensor data.
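The failure mode described above can be reproduced without the bridge. The following is a minimal sketch under stated assumptions: `FakeLazyTensor` and `materialize` are hypothetical stand-ins (not the real `LazyTensor` or `_set_state_dict`), and plain lists take the place of tensors.

```python
class FakeLazyTensor:
    """Hypothetical stand-in for LazyTensor: the value is produced only on .load()."""

    def __init__(self, loader):
        self._loader = loader

    def load(self):
        return self._loader()


def materialize(state_dict):
    """Mimics the behaviour the review describes: .load() is called on every value."""
    return {k: v.load() for k, v in state_dict.items()}


def eager_converter(k, v, adapter_name='default'):
    # Loads the weight to transform it and returns the raw (non-lazy) result.
    return k, [x * 2 for x in v.load()]


def lazy_converter(k, v, adapter_name='default'):
    # Defers the same transformation by wrapping it in a new lazy object.
    return k, FakeLazyTensor(lambda: [x * 2 for x in v.load()])


state_dict = {'w': FakeLazyTensor(lambda: [1.0, 2.0])}

eager = dict(eager_converter(k, v) for k, v in state_dict.items())
try:
    materialize(eager)
    eager_failed = False
except AttributeError:
    # A plain list has no .load() method, just as a raw tensor would not.
    eager_failed = True

lazy = dict(lazy_converter(k, v) for k, v in state_dict.items())
lazy_result = materialize(lazy)
```

In this sketch only the converter that returns a lazy wrapper survives the later `.load()` call, which is the asymmetry the comment points out.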
    if converter:
        k, v = converter(k, v, adapter_name=adapter_name)
The unpacking k, v = converter(k, v, ...) assumes that the converter always returns a tuple of length 2. If a user wants to use the converter to filter out certain weights by returning None, this line will raise a TypeError. It is safer to check the return value before unpacking to allow for filtering.
References
- Ensure robust handling of callback return values to prevent runtime errors when users attempt to filter items.
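To illustrate the unpacking concern, here is a small sketch of a converter used as a filter. The helper `apply_converter` and the converter `drop_bias` are hypothetical names introduced for this example; they are not part of the PR.

```python
def apply_converter(items, converter=None):
    """Yield (key, value) pairs, skipping any pair for which the converter returns None."""
    for k, v in items:
        if converter is not None:
            kv = converter(k, v)
            if kv is None:
                # The converter chose to drop this weight.
                continue
            k, v = kv
        yield k, v


def drop_bias(k, v):
    # Hypothetical filter: discard bias entries, pass everything else through.
    if k.endswith('.bias'):
        return None
    return k, v


weights = [('layer.weight', 1), ('layer.bias', 2)]

# Direct unpacking fails as soon as the converter returns None.
try:
    for k, v in weights:
        k, v = drop_bias(k, v)
    unpack_failed = False
except TypeError:
    unpack_failed = True

# Checking the return value first lets the same converter act as a filter.
filtered = dict(apply_converter(weights, drop_bias))
```

Checking for `None` before unpacking keeps the converter contract simple: return a pair to keep (or rename) a weight, return `None` to drop it.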
    hf_model_dir: str,
    peft_format: bool = False,
    adapter_name: str = 'default',
    converter: Callable = None,
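The review summary suggests annotating the argument as `Optional[Callable]`, since its default is `None`. A possible shape is sketched below; the `Converter` alias and the function `save_weights_sketch` are hypothetical and only mirror the parameter list shown in the diff.

```python
from typing import Any, Callable, Optional, Tuple

# Hypothetical alias: a converter returns a (key, value) pair, or None to drop the entry.
Converter = Callable[..., Optional[Tuple[str, Any]]]


def save_weights_sketch(
    hf_model_dir: str,
    peft_format: bool = False,
    adapter_name: str = 'default',
    converter: Optional[Converter] = None,
) -> bool:
    """Illustrative signature only; reports whether a converter was supplied."""
    return converter is not None
```

With `Optional[...]`, the annotation matches the `None` default, so static type checkers accept calls that omit the argument.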
/gemini review
Code Review
This pull request introduces a converter callback to the load_weights, export_weights, and save_weights methods in GPTBridge, allowing for custom key-value transformations. It also exposes the adapter_name parameter in the export and save methods. A potential AttributeError was identified in load_weights where the converter might return a raw torch.Tensor instead of a LazyTensor, and a suggestion was provided to wrap the output to ensure compatibility.
    if converter:
        new_state_dict = {}
        for k, v in state_dict.items():
            kv = converter(k, v)
            if kv is None:
                continue
            k, v = kv
            new_state_dict[k] = v
        state_dict = new_state_dict
The bridge's internal conversion logic (e.g., in _set_state_dict) explicitly calls .load() on every value in the state dict. If the user-provided converter returns a raw torch.Tensor instead of a LazyTensor (which happens if they call .load() inside the converter to modify the weight), the subsequent code will fail with an AttributeError. You should ensure that any value returned by the converter that doesn't have a .load() method is wrapped appropriately to maintain compatibility with the rest of the bridge.
    - if converter:
    -     new_state_dict = {}
    -     for k, v in state_dict.items():
    -         kv = converter(k, v)
    -         if kv is None:
    -             continue
    -         k, v = kv
    -         new_state_dict[k] = v
    -     state_dict = new_state_dict
    + if converter:
    +     from mcore_bridge.utils.safetensors import LazyTensor
    +     new_state_dict = {}
    +     for k, v in state_dict.items():
    +         kv = converter(k, v)
    +         if kv is None:
    +             continue
    +         k, v = kv
    +         if not hasattr(v, 'load'):
    +             v = LazyTensor(tensor=v)
    +         new_state_dict[k] = v
    +     state_dict = new_state_dict
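The suggested wrapping can be exercised in isolation. This is a sketch, not the bridge's implementation: `FakeLazyTensor` is a hypothetical stand-in whose `tensor=` keyword only mirrors the call in the suggestion, `convert_state_dict` and `scale_and_filter` are illustrative names, and lists replace tensors.

```python
class FakeLazyTensor:
    """Hypothetical stand-in for LazyTensor that simply holds an in-memory value."""

    def __init__(self, tensor=None):
        self._tensor = tensor

    def load(self):
        return self._tensor


def convert_state_dict(state_dict, converter=None):
    """Apply the converter, drop entries it returns None for, and re-wrap raw values."""
    if converter is None:
        return state_dict
    new_state_dict = {}
    for k, v in state_dict.items():
        kv = converter(k, v)
        if kv is None:
            continue
        k, v = kv
        if not hasattr(v, 'load'):
            # The converter returned a materialized value; wrap it so .load() still works.
            v = FakeLazyTensor(tensor=v)
        new_state_dict[k] = v
    return new_state_dict


def scale_and_filter(k, v):
    # Hypothetical converter: drops 'skip.' keys and returns a raw, scaled value.
    if k.startswith('skip.'):
        return None
    return k, [x * 10 for x in v.load()]


state = {
    'a.weight': FakeLazyTensor(tensor=[1, 2]),
    'skip.weight': FakeLazyTensor(tensor=[3]),
}
converted = convert_state_dict(state, scale_and_filter)
loaded = {k: v.load() for k, v in converted.items()}
```

Duck-typing on `.load()` rather than checking the class keeps values that are already lazy untouched and wraps only what the converter materialized.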
No description provided.