skyrl/skyrl/backends/jax.py (18 changes: 13 additions & 5 deletions)

@@ -53,6 +53,7 @@
     insert_adapter_state,
     round_up_seq_len,
     resolve_model_path,
+    get_adapter_idx,
 )
 from skyrl.utils.log import logger
@@ -127,15 +128,22 @@ def add(self, lora_grads: nnx.State, adapter_indices: jax.Array) -> "AccumulatedGradients":
     def get_mean(self, adapter_index: jax.Array) -> nnx.State:
         """Compute mean gradients for a specific adapter, with zeros for all other adapters."""
         count = self.counts[adapter_index]
-        return jax.tree.map(
-            lambda g: jnp.zeros_like(g).at[adapter_index].set(g[adapter_index] / count.astype(g.dtype)),
-            self.grad_sum,
-        )
+
+        def compute_mean(path, g):
+            idx = get_adapter_idx(path, adapter_index)
+            return jnp.zeros_like(g).at[idx].set(g[idx] / count.astype(g.dtype))
+
+        return jax.tree.map_with_path(compute_mean, self.grad_sum)
 
     def reset_adapter(self, adapter_index: jax.Array) -> "AccumulatedGradients":
         """Reset gradients and count for a specific adapter."""
+
+        def reset_grad(path, g):
+            idx = get_adapter_idx(path, adapter_index)
+            return g.at[idx].set(0.0)
+
         return AccumulatedGradients(
-            grad_sum=jax.tree.map(lambda g: g.at[adapter_index].set(0.0), self.grad_sum),
+            grad_sum=jax.tree.map_with_path(reset_grad, self.grad_sum),
             counts=self.counts.at[adapter_index].set(0),
         )
 
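Context for the change above: jax.tree.map hands each callback only the leaf value, while jax.tree.map_with_path also passes the leaf's key path, which is what lets get_adapter_idx choose a different index per parameter. A minimal, self-contained demo of the path mechanism (the toy state and names below are illustrative, not SkyRL's actual layout):

import jax
import jax.numpy as jnp

# Toy pytree standing in for an nnx.State; the layout is illustrative only.
state = {
    "q_proj": {"lora_A": jnp.ones((4, 8, 2))},   # (adapters, in_dim, rank)
    "lora_ranks": jnp.zeros((4,), dtype=jnp.int32),
}

def show(path, leaf):
    # For dict pytrees each path entry is a jax.tree_util.DictKey whose .key
    # attribute holds the dict key. With nnx.State the leaf sits one wrapper
    # deeper, which is presumably why the diff reads path[-2].key rather than
    # path[-1].key.
    print([entry.key for entry in path], leaf.shape)
    return leaf

jax.tree.map_with_path(show, state)
# prints: ['lora_ranks'] (4,)
#         ['q_proj', 'lora_A'] (4, 8, 2)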
skyrl/skyrl/tx/layers/lora.py (26 changes: 13 additions & 13 deletions)

@@ -2,7 +2,7 @@
 import jax
 from jax import numpy as jnp
 
-from skyrl.tx.utils.models import filter_lora
+from skyrl.tx.utils.models import filter_lora, get_adapter_idx
 from skyrl.tx.layers.util import Param, prepare_routing, ragged_dot
 from skyrl.tx.models.types import ModelForCausalLM
 from skyrl.tinker.types import LoraConfig
@@ -345,21 +345,22 @@ def init_adapter(path, value):
         if not filter_lora(lora_config, normalized_path):
             effective_rank = 0
 
+        idx = get_adapter_idx(path, adapter_index)
+
         key_name = path[-2].key
         if key_name == "lora_ranks":
-            return value.at[adapter_index].set(effective_rank)
+            return value.at[idx].set(effective_rank)
         if key_name == "lora_scaling":
             # Set scaling to 0.0 if rank is 0
-            return value.at[adapter_index].set(lora_config.alpha / effective_rank if effective_rank > 0 else 0.0)
+            scaling = lora_config.alpha / effective_rank if effective_rank > 0 else 0.0
+            return value.at[idx].set(scaling)
         if key_name == "lora_A":
             # Reinitialize with he_uniform, then zero columns beyond rank
-            shape = value[adapter_index].shape
-            new_A = nnx.initializers.he_uniform()(rngs.params(), shape, value.dtype)
+            new_A = nnx.initializers.he_uniform()(rngs.params(), value[idx].shape, value.dtype)
             new_A = new_A.at[..., effective_rank:].set(0.0)
-            return value.at[adapter_index].set(new_A)
+            return value.at[idx].set(new_A)
         if key_name == "lora_B":
             # Explicitly zero lora_B
-            return value.at[adapter_index].set(0.0)
+            return value.at[idx].set(0.0)
         return value
 
     updated_state = jax.tree.map_with_path(init_adapter, state)
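get_adapter_idx itself is defined elsewhere (in skyrl.tx.utils.models, not shown in this diff), so its exact rule is unknown here. A hypothetical sketch of the contract the call sites imply: it maps a tree path plus an adapter index to an index expression usable with .at[idx], returning a plain index for most leaves and something richer for leaves whose adapter axis is not leading. Everything below, the "expert" marker in particular, is an assumption, not SkyRL's implementation:

# HYPOTHETICAL sketch -- the real get_adapter_idx lives in
# skyrl.tx.utils.models and is not part of this diff.
def get_adapter_idx(path, adapter_index):
    # Assumption: most stacked LoRA tensors put the adapter axis first, but
    # some (guessed here: expert-stacked MoE weights) carry an extra leading
    # axis, so the adapter index must target axis 1 instead of axis 0.
    has_extra_axis = any("expert" in str(getattr(p, "key", "")) for p in path)
    if has_extra_axis:
        return (slice(None), adapter_index)  # keep axis 0 whole, index axis 1
    return adapter_index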
@@ -376,11 +377,10 @@ def clear_lora_adapter(model: ModelForCausalLM, adapter_index: int):
 
     def clear_adapter(path, value):
         key = path[-2].key
-        if key == "lora_ranks":
-            return value.at[adapter_index].set(0)
-        if key in ("lora_scaling", "lora_A", "lora_B"):
-            return value.at[adapter_index].set(0.0)
-        return value
+        if key not in ("lora_ranks", "lora_scaling", "lora_A", "lora_B"):
+            return value
+        idx = get_adapter_idx(path, adapter_index)
+        return value.at[idx].set(0 if key == "lora_ranks" else 0.0)
 
     updated_state = jax.tree.map_with_path(clear_adapter, state)
     nnx.update(model, updated_state)
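Worth noting why clear_adapter and the other call sites can stay written as value.at[idx].set(...) no matter what the helper returns: JAX's .at[] accepts the same index expressions as ordinary indexing, including tuples containing slices. A quick self-contained check (shapes are illustrative):

import jax.numpy as jnp

x = jnp.ones((2, 4, 3))    # e.g. (experts, adapters, rank), illustrative
idx = (slice(None), 1)     # the kind of index a path-aware helper might return
y = x.at[idx].set(0.0)     # zeros adapter 1 across the whole leading axis
print(y[:, 1].sum())       # 0.0
print(y[:, 0].sum())       # 6.0 -- other adapters untouched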