Commit 1e63fa6

add LoRALayer (#472)
1 parent 281c0bd commit 1e63fa6

File tree

5 files changed, +166 -228 lines changed


python/sgl_jax/srt/lora/backend/base_backend.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import jax
22

3-
from sgl_jax.srt.lora.utils import LoRABatchInfo
43
from sgl_jax.srt.model_executor.forward_batch_info import ForwardBatch
54

65

@@ -94,7 +93,6 @@ def prepare_lora_batch(
9493
weight_indices: list[int],
9594
lora_ranks: list[int],
9695
scalings: list[float],
97-
batch_info: LoRABatchInfo | None = None,
9896
):
9997
"""Prepare the lora weights and batch info for current forward batch.
10098

python/sgl_jax/srt/lora/backend/bgmv_backend.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ def prepare_lora_batch(
165165
weight_indices: list[int],
166166
lora_ranks: list[int],
167167
scalings: list[float],
168-
batch_info: LoRABatchInfo | None = None,
169168
):
170169
lora_ranks_bs = []
171170
scalings_bs = []

python/sgl_jax/srt/lora/layers.py

Lines changed: 68 additions & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -21,103 +21,103 @@
2121
import jax
2222
from flax import nnx
2323

24+
from python.sgl_jax.srt.layers.linear import LinearBase
25+
2426
if TYPE_CHECKING:
2527
from sgl_jax.srt.lora.backend.base_backend import BaseLoRABackend
2628

2729

28-
class LoRALinear(nnx.Module):
30+
class BaseLayerWithLoRA(nnx.Module):
31+
def __init__(
32+
self,
33+
base_layer: nnx.Module,
34+
lora_backend: BaseLoRABackend,
35+
):
36+
super().__init__()
37+
self.base_layer: nnx.Module = base_layer
38+
self.set_lora: bool = False
39+
self.lora_backend: BaseLoRABackend = lora_backend
40+
if hasattr(self.base_layer, "weight"):
41+
self.weight = self.base_layer.weight
42+
43+
def __call__(self, x: jax.Array):
44+
return self.base_layer(x)
45+
46+
def set_lora_info(self, *args):
47+
pass
48+
49+
50+
class LoRALinear(BaseLayerWithLoRA):
2951
"""
3052
LoRA wrapper for Linear layers using Flax NNX.
31-
3253
This wraps an existing Linear layer and adds LoRA (Low-Rank Adaptation)
3354
computation. Uses Model Surgery to preserve the original weights and sharding.
3455
35-
V1 implementation uses backend to perform LoRA computation:
36-
output = base_layer(x)
37-
if enabled:
38-
lora_output = backend.run_lora_a_gemm(x, lora_A_weights)
39-
output = backend.run_lora_b_gemm(lora_output, lora_B_weights, output)
40-
4156
Attributes:
4257
base_layer: Original Linear layer (preserves weights and sharding)
43-
lora_rank: LoRA rank dimension
4458
backend: LoRA backend for efficient computation
45-
enabled: Whether LoRA computation is active
4659
"""
4760

4861
def __init__(
4962
self,
50-
in_features: int,
51-
out_features: int,
52-
lora_rank: int,
53-
base_layer: nnx.Linear | None = None,
54-
backend: BaseLoRABackend | None = None,
55-
rngs: nnx.Rngs | None = None,
63+
base_layer: LinearBase | None = None,
64+
lora_backend: BaseLoRABackend | None = None,
5665
):
5766
"""
5867
Initialize LoRA Linear layer.
5968
6069
Args:
61-
in_features: Input dimension
62-
out_features: Output dimension
63-
lora_rank: Rank of LoRA matrices
64-
base_layer: Existing Linear layer to wrap (optional)
65-
backend: LoRA backend for computation (optional)
66-
rngs: Random number generators for initialization
70+
base_layer: Existing Linear layer to wrap
71+
backend: LoRA backend for computation
6772
"""
68-
self.in_features = in_features
69-
self.out_features = out_features
70-
self.lora_rank = lora_rank
71-
self.backend = backend
72-
73-
# Base layer - will be populated via nnx.update() during surgery
74-
if base_layer is not None:
75-
self.base_layer = base_layer
76-
else:
77-
# Create placeholder base layer
78-
if rngs is None:
79-
rngs = nnx.Rngs(0)
80-
self.base_layer = nnx.Linear(
81-
in_features,
82-
out_features,
83-
use_bias=True,
84-
rngs=rngs,
85-
)
86-
87-
# Control variable (not trainable)
88-
self.enabled = nnx.Variable(False) # Whether LoRA is active
89-
90-
def __call__(self, x: jax.Array) -> jax.Array:
73+
super().__init__(base_layer, lora_backend)
74+
self.lora_backend = lora_backend
75+
76+
def set_lora_info(
77+
self,
78+
A_buffer: jax.Array,
79+
B_buffer: jax.Array,
80+
):
81+
self.set_lora = True
82+
self.A_buffer = A_buffer
83+
self.B_buffer = B_buffer
84+
85+
def apply_lora(self, base_output: jax.Array, x: jax.Array) -> jax.Array:
86+
lora_a_output = self.lora_backend.run_lora_a_gemm(x, self.A_buffer)
87+
lora_output = self.lora_backend.run_lora_b_gemm(
88+
x=lora_a_output,
89+
weights=self.B_buffer,
90+
base_output=base_output,
91+
)
92+
return lora_output
93+
94+
def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array | None]:
9195
"""
9296
Forward pass with optional LoRA computation using backend.
9397
9498
Args:
9599
x: Input tensor (shape: [seq_len, in_features])
96100
97101
Returns:
98-
Output tensor with LoRA delta added (if enabled)
102+
Output tensor with LoRA delta added (if enabled) and bias from base_model
99103
"""
100104
# Base layer computation (preserves original behavior)
101-
output = self.base_layer(x)
102-
103-
# Add LoRA delta if enabled and backend is available
104-
if self.enabled.value and self.backend is not None:
105-
# Get LoRA weights from memory pool via backend
106-
# Backend handles batched LoRA computation for multiple adapters
107-
108-
# Step 1: Shrink - project to low-rank space
109-
# lora_A_weights fetched from memory pool based on batch_info
110-
lora_a_output = self.backend.run_lora_a_gemm(
111-
x, None
112-
) # Backend manages weights internally
105+
output_bias = jax.lax.cond(
106+
not self.base_layer.skip_bias_add,
107+
lambda operands: operands[0],
108+
lambda operands: None,
109+
(self.base_layer.bias),
110+
)
111+
base_output = self.base_layer(x)
113112

114-
# Step 2: Expand - project back to output space and add to base output
115-
output = self.backend.run_lora_b_gemm(lora_a_output, None, output)
113+
output = jax.lax.cond(
114+
self.set_lora, self.apply_lora, lambda operands: operands[0], (base_output, x)
115+
)
116116

117-
return output
117+
return output, output_bias
118118

119119

120-
class LoRAEmbedding(nnx.Module):
120+
class LoRAEmbedding(BaseLayerWithLoRA):
121121
"""
122122
LoRA wrapper for Embedding layers.
123123
@@ -127,61 +127,15 @@ class LoRAEmbedding(nnx.Module):
127127

128128
def __init__(
129129
self,
130-
num_embeddings: int,
131-
features: int,
132-
lora_rank: int,
133-
base_layer: nnx.Embed | None = None,
134-
backend: BaseLoRABackend | None = None,
135-
rngs: nnx.Rngs | None = None,
130+
base_layer: LinearBase | None = None,
131+
lora_backend: BaseLoRABackend | None = None,
136132
):
137133
"""
138134
Initialize LoRA Embedding layer.
139135
140136
Args:
141-
num_embeddings: Size of vocabulary
142-
features: Embedding dimension
143-
lora_rank: Rank of LoRA matrices
144-
base_layer: Existing Embed layer to wrap (optional)
145-
backend: LoRA backend for computation (optional)
146-
rngs: Random number generators
137+
base_layer: Existing Embed layer to wrap
138+
backend: LoRA backend for computation
147139
"""
148-
self.num_embeddings = num_embeddings
149-
self.features = features
150-
self.lora_rank = lora_rank
151-
self.backend = backend
152-
153-
# Base layer
154-
if base_layer is not None:
155-
self.base_layer = base_layer
156-
else:
157-
if rngs is None:
158-
rngs = nnx.Rngs(0)
159-
self.base_layer = nnx.Embed(
160-
num_embeddings,
161-
features,
162-
rngs=rngs,
163-
)
164-
165-
# Control variable
166-
self.enabled = nnx.Variable(False)
167-
168-
def __call__(self, x: jax.Array) -> jax.Array:
169-
"""
170-
Forward pass for embedding with LoRA using backend.
171-
172-
Args:
173-
x: Input token indices
174-
175-
Returns:
176-
Embedded output with LoRA delta (if enabled)
177-
"""
178-
output = self.base_layer(x)
179-
180-
# V1: Embedding LoRA computation via backend
181-
# TODO: Implement embedding-specific backend methods if needed
182-
# For now, embeddings use simple pass-through
183-
if self.enabled.value and self.backend is not None:
184-
# Backend handles embedding LoRA computation
185-
pass
186-
187-
return output
140+
super().__init__(base_layer, lora_backend)
141+
self.weight = base_layer.weight
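
For reference, the shrink/expand computation that the new apply_lora method delegates to run_lora_a_gemm and run_lora_b_gemm is the standard LoRA update: project the input into the low-rank space with the A buffer, project back with the B buffer, and add the result to the base layer's output. Below is a minimal standalone sketch in JAX; the two helper functions are illustrative stand-ins for the backend methods (per-adapter scaling and rank handling from prepare_lora_batch are omitted), not the project's actual implementations.

# Minimal sketch of the LoRA shrink/expand path that apply_lora delegates to the
# backend. run_lora_a_gemm / run_lora_b_gemm below are illustrative stand-ins,
# not the project's backend implementations.
import jax
import jax.numpy as jnp


def run_lora_a_gemm(x: jax.Array, lora_a: jax.Array) -> jax.Array:
    # Shrink: project activations into the low-rank space.
    # x: [seq_len, in_features], lora_a: [in_features, rank]
    return x @ lora_a


def run_lora_b_gemm(x: jax.Array, weights: jax.Array, base_output: jax.Array) -> jax.Array:
    # Expand: project back to the output space and add to the base output.
    # x: [seq_len, rank], weights: [rank, out_features]
    return base_output + x @ weights


key = jax.random.PRNGKey(0)
k_x, k_w, k_a = jax.random.split(key, 3)
seq_len, in_features, out_features, rank = 8, 16, 32, 4

x = jax.random.normal(k_x, (seq_len, in_features))
w = jax.random.normal(k_w, (in_features, out_features))
lora_a = jax.random.normal(k_a, (in_features, rank))
lora_b = jnp.zeros((rank, out_features))  # B is zero-initialized, so LoRA starts as a no-op

base_output = x @ w                                            # what base_layer(x) would produce
lora_a_output = run_lora_a_gemm(x, lora_a)                     # shrink
output = run_lora_b_gemm(lora_a_output, lora_b, base_output)   # expand and add to base output

assert jnp.allclose(output, base_output)  # a zero B buffer leaves the base output unchanged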
