 import jax
 from flax import nnx
 
+from sgl_jax.srt.layers.linear import LinearBase
+
 if TYPE_CHECKING:
     from sgl_jax.srt.lora.backend.base_backend import BaseLoRABackend
 
 
-class LoRALinear(nnx.Module):
+class BaseLayerWithLoRA(nnx.Module):
+    def __init__(
+        self,
+        base_layer: nnx.Module,
+        lora_backend: BaseLoRABackend,
+    ):
+        super().__init__()
+        self.base_layer: nnx.Module = base_layer
+        self.set_lora: bool = False
+        self.lora_backend: BaseLoRABackend = lora_backend
+        if hasattr(self.base_layer, "weight"):
+            self.weight = self.base_layer.weight
+
+    def __call__(self, x: jax.Array):
+        return self.base_layer(x)
+
+    def set_lora_info(self, *args):
+        pass
+
+
+class LoRALinear(BaseLayerWithLoRA):
     """
     LoRA wrapper for Linear layers using Flax NNX.
-
     This wraps an existing Linear layer and adds LoRA (Low-Rank Adaptation)
     computation. Uses Model Surgery to preserve the original weights and sharding.
 
-    V1 implementation uses backend to perform LoRA computation:
-        output = base_layer(x)
-        if enabled:
-            lora_output = backend.run_lora_a_gemm(x, lora_A_weights)
-            output = backend.run_lora_b_gemm(lora_output, lora_B_weights, output)
-
     Attributes:
         base_layer: Original Linear layer (preserves weights and sharding)
-        lora_rank: LoRA rank dimension
         backend: LoRA backend for efficient computation
-        enabled: Whether LoRA computation is active
     """
 
     def __init__(
         self,
-        in_features: int,
-        out_features: int,
-        lora_rank: int,
-        base_layer: nnx.Linear | None = None,
-        backend: BaseLoRABackend | None = None,
-        rngs: nnx.Rngs | None = None,
+        base_layer: LinearBase | None = None,
+        lora_backend: BaseLoRABackend | None = None,
     ):
         """
         Initialize LoRA Linear layer.
 
         Args:
-            in_features: Input dimension
-            out_features: Output dimension
-            lora_rank: Rank of LoRA matrices
-            base_layer: Existing Linear layer to wrap (optional)
-            backend: LoRA backend for computation (optional)
-            rngs: Random number generators for initialization
+            base_layer: Existing Linear layer to wrap
+            lora_backend: LoRA backend for computation
         """
-        self.in_features = in_features
-        self.out_features = out_features
-        self.lora_rank = lora_rank
-        self.backend = backend
-
-        # Base layer - will be populated via nnx.update() during surgery
-        if base_layer is not None:
-            self.base_layer = base_layer
-        else:
-            # Create placeholder base layer
-            if rngs is None:
-                rngs = nnx.Rngs(0)
-            self.base_layer = nnx.Linear(
-                in_features,
-                out_features,
-                use_bias=True,
-                rngs=rngs,
-            )
-
-        # Control variable (not trainable)
-        self.enabled = nnx.Variable(False)  # Whether LoRA is active
-
-    def __call__(self, x: jax.Array) -> jax.Array:
+        super().__init__(base_layer, lora_backend)
+        self.lora_backend = lora_backend
+
+    def set_lora_info(
+        self,
+        A_buffer: jax.Array,
+        B_buffer: jax.Array,
+    ):
+        self.set_lora = True
+        self.A_buffer = A_buffer
+        self.B_buffer = B_buffer
+
+    def apply_lora(self, base_output: jax.Array, x: jax.Array) -> jax.Array:
+        # Shrink: project the input into the low-rank space with the stacked LoRA A weights
+        lora_a_output = self.lora_backend.run_lora_a_gemm(x, self.A_buffer)
+        # Expand: project back to the output space and fuse the delta with the base output
+        lora_output = self.lora_backend.run_lora_b_gemm(
+            x=lora_a_output,
+            weights=self.B_buffer,
+            base_output=base_output,
+        )
+        return lora_output
+
+    def __call__(self, x: jax.Array) -> tuple[jax.Array, jax.Array | None]:
         """
         Forward pass with optional LoRA computation using backend.
 
         Args:
             x: Input tensor (shape: [seq_len, in_features])
 
         Returns:
-            Output tensor with LoRA delta added (if enabled)
+            Output tensor with the LoRA delta added (when LoRA weights are set) and the bias from the base layer
         """
         # Base layer computation (preserves original behavior)
-        output = self.base_layer(x)
-
-        # Add LoRA delta if enabled and backend is available
-        if self.enabled.value and self.backend is not None:
-            # Get LoRA weights from memory pool via backend
-            # Backend handles batched LoRA computation for multiple adapters
-
-            # Step 1: Shrink - project to low-rank space
-            # lora_A_weights fetched from memory pool based on batch_info
-            lora_a_output = self.backend.run_lora_a_gemm(
-                x, None
-            )  # Backend manages weights internally
+        # Bias from the base layer, or None when skip_bias_add is set
+        output_bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
+        base_output = self.base_layer(x)
 
-            # Step 2: Expand - project back to output space and add to base output
-            output = self.backend.run_lora_b_gemm(lora_a_output, None, output)
+        if self.set_lora:
+            output = self.apply_lora(base_output, x)
+        else:
+            output = base_output
 
-        return output
+        return output, output_bias
 
 
-class LoRAEmbedding(nnx.Module):
+class LoRAEmbedding(BaseLayerWithLoRA):
     """
     LoRA wrapper for Embedding layers.
 
@@ -127,61 +127,15 @@ class LoRAEmbedding(nnx.Module):
 
     def __init__(
         self,
-        num_embeddings: int,
-        features: int,
-        lora_rank: int,
-        base_layer: nnx.Embed | None = None,
-        backend: BaseLoRABackend | None = None,
-        rngs: nnx.Rngs | None = None,
+        base_layer: LinearBase | None = None,
+        lora_backend: BaseLoRABackend | None = None,
     ):
         """
         Initialize LoRA Embedding layer.
 
         Args:
-            num_embeddings: Size of vocabulary
-            features: Embedding dimension
-            lora_rank: Rank of LoRA matrices
-            base_layer: Existing Embed layer to wrap (optional)
-            backend: LoRA backend for computation (optional)
-            rngs: Random number generators
+            base_layer: Existing Embed layer to wrap
+            lora_backend: LoRA backend for computation
         """
-        self.num_embeddings = num_embeddings
-        self.features = features
-        self.lora_rank = lora_rank
-        self.backend = backend
-
-        # Base layer
-        if base_layer is not None:
-            self.base_layer = base_layer
-        else:
-            if rngs is None:
-                rngs = nnx.Rngs(0)
-            self.base_layer = nnx.Embed(
-                num_embeddings,
-                features,
-                rngs=rngs,
-            )
-
-        # Control variable
-        self.enabled = nnx.Variable(False)
-
-    def __call__(self, x: jax.Array) -> jax.Array:
-        """
-        Forward pass for embedding with LoRA using backend.
-
-        Args:
-            x: Input token indices
-
-        Returns:
-            Embedded output with LoRA delta (if enabled)
-        """
-        output = self.base_layer(x)
-
-        # V1: Embedding LoRA computation via backend
-        # TODO: Implement embedding-specific backend methods if needed
-        # For now, embeddings use simple pass-through
-        if self.enabled.value and self.backend is not None:
-            # Backend handles embedding LoRA computation
-            pass
-
-        return output
+        super().__init__(base_layer, lora_backend)
+        self.weight = base_layer.weight
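
Below is a minimal usage sketch (not part of the commit) of how the new wrapper classes are meant to be driven. The StubLinear and StubBackend classes are hypothetical stand-ins that implement only what the wrappers actually touch (weight, bias, skip_bias_add, run_lora_a_gemm, run_lora_b_gemm); the buffer shapes follow the stub's own convention rather than the real memory pool layout.

import jax
import jax.numpy as jnp
from flax import nnx


class StubLinear(nnx.Module):
    """Hypothetical stand-in for LinearBase: weight, bias, and a skip_bias_add flag."""

    def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
        self.weight = nnx.Param(jax.random.normal(rngs.params(), (in_features, out_features)) * 0.02)
        self.bias = nnx.Param(jnp.zeros((out_features,)))
        self.skip_bias_add = False

    def __call__(self, x: jax.Array) -> jax.Array:
        return x @ self.weight.value


class StubBackend:
    """Hypothetical backend exposing the two GEMMs used by LoRALinear.apply_lora."""

    def run_lora_a_gemm(self, x: jax.Array, weights: jax.Array) -> jax.Array:
        # Shrink: [tokens, in] @ [in, rank] -> [tokens, rank]
        return x @ weights

    def run_lora_b_gemm(self, x: jax.Array, weights: jax.Array, base_output: jax.Array) -> jax.Array:
        # Expand and fuse: base_output + [tokens, rank] @ [rank, out]
        return base_output + x @ weights


# Wrap a base layer, attach LoRA buffers, then run a forward pass.
base = StubLinear(16, 32, rngs=nnx.Rngs(0))
layer = LoRALinear(base_layer=base, lora_backend=StubBackend())
rank = 8
layer.set_lora_info(A_buffer=jnp.zeros((16, rank)), B_buffer=jnp.zeros((rank, 32)))
output, output_bias = layer(jnp.ones((4, 16)))  # output: [4, 32]; bias surfaced separately

The surgery-style design means LoRALinear never copies the wrapped layer's parameters: it keeps a reference to base_layer (and its weight), and LoRA only adds the A/B buffers plus the set_lora flag on top.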