Skip to content

Commit 6b711d7

Browse files
authored
Port #1079 to skyrl folder (#1127)
See #1079 <!-- devin-review-badge-begin --> --- <a href="https://app.devin.ai/review/novasky-ai/skyrl/pull/1127" target="_blank"> <picture> <source media="(prefers-color-scheme: dark)" srcset="https://static.devin.ai/assets/gh-open-in-devin-review-dark.svg?v=1"> <img src="https://static.devin.ai/assets/gh-open-in-devin-review-light.svg?v=1" alt="Open with Devin"> </picture> </a> <!-- devin-review-badge-end -->
1 parent ba3e4db commit 6b711d7

12 files changed

Lines changed: 642 additions & 108 deletions

File tree

skyrl/skyrl/backends/jax.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@
5353
insert_adapter_state,
5454
round_up_seq_len,
5555
resolve_model_path,
56+
get_adapter_idx,
5657
)
5758
from skyrl.utils.log import logger
5859

@@ -127,15 +128,22 @@ def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "Accumulated
127128
def get_mean(self, adapter_index: jax.Array) -> nnx.State:
128129
"""Compute mean gradients for a specific adapter, with zeros for all other adapters."""
129130
count = self.counts[adapter_index]
130-
return jax.tree.map(
131-
lambda g: jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)),
132-
self.grad_sum,
133-
)
131+
132+
def compute_mean(path, g):
133+
idx = get_adapter_idx(path, adapter_index)
134+
return jnp.zeros_like(g).at[idx].set(g[idx] / count.astype(g.dtype))
135+
136+
return jax.tree.map_with_path(compute_mean, self.grad_sum)
134137

135138
def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients":
136139
"""Reset gradients and count for a specific adapter."""
140+
141+
def reset_grad(path, g):
142+
idx = get_adapter_idx(path, adapter_index)
143+
return g.at[idx].set(0.0)
144+
137145
return AccumulatedGradients(
138-
grad_sum=jax.tree.map(lambda g: g.at[adapter_index].set(0.0), self.grad_sum),
146+
grad_sum=jax.tree.map_with_path(reset_grad, self.grad_sum),
139147
counts=self.counts.at[adapter_index].set(0),
140148
)
141149

skyrl/skyrl/tx/layers/lora.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@
22
import jax
33
from jax import numpy as jnp
44

5-
from skyrl.tx.utils.models import filter_lora
5+
from skyrl.tx.utils.models import filter_lora, get_adapter_idx
66
from skyrl.tx.layers.util import Param, prepare_routing, ragged_dot
77
from skyrl.tx.models.types import ModelForCausalLM
88
from skyrl.tinker.types import LoraConfig
@@ -345,21 +345,22 @@ def init_adapter(path, value):
345345
if not filter_lora(lora_config, normalized_path):
346346
effective_rank = 0
347347

348+
idx = get_adapter_idx(path, adapter_index)
349+
348350
key_name = path[-2].key
349351
if key_name == "lora_ranks":
350-
return value.at[adapter_index].set(effective_rank)
352+
return value.at[idx].set(effective_rank)
351353
if key_name == "lora_scaling":
352-
# Set scaling to 0.0 if rank is 0
353-
return value.at[adapter_index].set(lora_config.alpha / effective_rank if effective_rank > 0 else 0.0)
354+
scaling = lora_config.alpha / effective_rank if effective_rank > 0 else 0.0
355+
return value.at[idx].set(scaling)
354356
if key_name == "lora_A":
355357
# Reinitialize with he_uniform, then zero columns beyond rank
356-
shape = value[adapter_index].shape
357-
new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype)
358+
new_A = nnx.initializers.he_uniform()(rngs.params(), value[idx].shape, value.dtype)
358359
new_A = new_A.at[..., effective_rank:].set(0.0)
359-
return value.at[adapter_index].set(new_A)
360+
return value.at[idx].set(new_A)
360361
if key_name == "lora_B":
361362
# Explicitly zero lora_B
362-
return value.at[adapter_index].set(0.0)
363+
return value.at[idx].set(0.0)
363364
return value
364365

365366
updated_state = jax.tree.map_with_path(init_adapter, state)
@@ -376,11 +377,10 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int):
376377

377378
def clear_adapter(path, value):
378379
key = path[-2].key
379-
if key == "lora_ranks":
380-
return value.at[adapter_index].set(0)
381-
if key in ("lora_scaling", "lora_A", "lora_B"):
382-
return value.at[adapter_index].set(0.0)
383-
return value
380+
if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"):
381+
return value
382+
idx = get_adapter_idx(path, adapter_index)
383+
return value.at[idx].set(0 if key == "lora_ranks" else 0.0)
384384

385385
updated_state = jax.tree.map_with_path(clear_adapter, state)
386386
nnx.update(model, updated_state)

0 commit comments

Comments
 (0)