Port https://github.com/NovaSky-AI/SkyRL/pull/1095 to skyrl folder #1129
Conversation
Code Review
This pull request refactors the model loading logic within the tests by centralizing it into a load_model helper function. This is a great improvement that simplifies the test files, removes redundant code, and eliminates the need for saving models to temporary directories. The changes make the tests cleaner and more maintainable. However, I've noticed a consistent omission of the shard_attention_heads=True parameter in the new load_model calls across multiple tests. This parameter was explicitly set in the previous implementation, and its absence could alter the model configuration and potentially impact the correctness of the tests. I've added specific comments with suggestions to address this.
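For readers without the full diff: the helper's exact signature is not shown in this excerpt, but the call sites quoted in the inline comments below suggest a shape roughly like the sketch here. Everything beyond those quoted call sites (the `make_mesh` utility and the constructor form in particular) is an assumption, not the PR's actual implementation.

```python
# Hypothetical sketch of the centralized helper; the real code in the PR may differ.
# Parameter names are inferred from the call sites quoted in the comments below.
def load_model(tmp, model_name, config_cls, model_cls, mesh_axes, **config_kwargs):
    """Build a config and instantiate the model in one place, so individual
    tests no longer duplicate this setup or save weights to temp directories."""
    # Per this review, shard_attention_heads=True should be threaded through
    # here (or accepted via config_kwargs) to match the pre-refactor behavior.
    config = config_cls(**config_kwargs)
    mesh = make_mesh(mesh_axes)  # assumed mesh-construction utility
    return model_cls(config, mesh=mesh)  # assumed constructor shape
```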
```python
    gradient_checkpointing=False,
)
```
The shard_attention_heads=True parameter appears to be missing from the call to load_model. The original load_model function in this file hardcoded this parameter, so its omission in the refactored code could change the model's configuration and affect the test's correctness. It should be added to preserve the previous behavior.
Suggested change:

```diff
     gradient_checkpointing=False,
+    shard_attention_heads=True,
 )
```
```python
model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size)
outputs = model(input_ids, attention_mask=attention_mask)
logprobs_chunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
common_kwargs = dict(max_lora_adapters=1, max_lora_rank=1, gradient_checkpointing=False)
```
The shard_attention_heads=True parameter seems to be missing from common_kwargs. The original model loading logic included this parameter. To ensure the test behaves as it did before the refactoring, this parameter should be added to the common keyword arguments passed to load_model.
Suggested change:

```diff
-common_kwargs = dict(max_lora_adapters=1, max_lora_rank=1, gradient_checkpointing=False)
+common_kwargs = dict(max_lora_adapters=1, max_lora_rank=1, gradient_checkpointing=False, shard_attention_heads=True)
```
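Since common_kwargs is presumably splatted into the subsequent load_model calls (not shown in this excerpt), adding the flag once propagates it to every variant in the test. A hedged illustration of the resulting call shape, reusing loss_chunk_size from the quoted context:

```python
common_kwargs = dict(
    max_lora_adapters=1,
    max_lora_rank=1,
    gradient_checkpointing=False,
    shard_attention_heads=True,
)
# Hypothetical call site: every model variant now shares identical sharding settings.
model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes,
                   loss_chunk_size=chunk_size, **common_kwargs)
```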
```python
    max_lora_rank=32,
)
```
The shard_attention_heads=True parameter is missing from the load_model call. The original code explicitly set shard_attention_heads=True when creating the Qwen3Config. This parameter should be added to the load_model call to ensure the model configuration remains the same after refactoring.
Suggested change:

```diff
     max_lora_rank=32,
+    shard_attention_heads=True,
 )
```
```python
    max_lora_rank=max(cfg.r for cfg in lora_configs),
)
```
The shard_attention_heads=True parameter is missing from the load_model call. The original code explicitly set this parameter when creating the Qwen3Config. To maintain the original test behavior, it should be included in the call to load_model.
Suggested change:

```diff
     max_lora_rank=max(cfg.r for cfg in lora_configs),
+    shard_attention_heads=True,
 )
```
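Beyond fixing each call site, one way to guard against this class of regression is to assert the resolved configuration after loading. This is a sketch, assuming the config object exposes the flag as a plain attribute:

```python
model = load_model(
    tmp, model_name, config_cls, model_cls, mesh_axes,
    max_lora_rank=max(cfg.r for cfg in lora_configs),
    shard_attention_heads=True,
)
# Fails fast if a future refactor silently drops the flag again.
assert model.config.shard_attention_heads, "attention heads are not sharded"
```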
See #1095