 
 from ...layers.feedforward import FeedForward, ParallelFeedForward
 from ...layers.linear import ColParallelLinear, Linear
-from ...layers.module import Module
+from ...layers.module import Module, ModuleList, Parameter
 from ...parallel.config import EncoderParallelConfig
 from ...parallel.manager import CCLManager
-from ...utils.substate import indexed_substates, substate
-from ...utils.tensor import bf16_tensor
+from ...utils.substate import rename_substate
 
 
 class CLIPConfig:
@@ -57,6 +56,7 @@ def __init__(
         layer_norm_eps: float = 1e-05,
         attention_dropout: float = 0.0,
         hidden_act: str = "quick_gelu",
+        projection_dim: int | None = None,
     ):
         self.vocab_size = vocab_size
         self.embed_dim = embed_dim
@@ -66,6 +66,7 @@ def __init__(
         self.max_prompt_length = max_prompt_length
         self.layer_norm_eps = layer_norm_eps
         self.attention_dropout = attention_dropout
+        self.projection_dim = projection_dim
         if hidden_act == "gelu":
             self.hidden_act = "decomposed_gelu"
         else:
@@ -97,36 +98,33 @@ def __init__(
         self.embeddings = TextEmbeddings(config, mesh_device)
         self.eos_token_id = eos_token_id
         self.encoder = CLIPStack(config, self.mesh_device, self.ccl_manager, self.parallel_config)
-        self.text_projection = None
 
-    def load_torch_state_dict(self, state_dict):
-        self.embeddings.load_torch_state_dict(substate(state_dict, "text_model.embeddings"))
-        self.encoder.load_torch_state_dict(substate(state_dict, "text_model.encoder"))
-
-        self.final_layer_norm = bf16_tensor(
-            state_dict["text_model.final_layer_norm.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.final_layer_norm_bias = bf16_tensor(
-            state_dict["text_model.final_layer_norm.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
+        self.final_layer_norm = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.final_layer_norm_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.text_projection = (
+            Parameter(total_shape=[config.embed_dim, config.projection_dim], device=mesh_device)
+            if config.projection_dim is not None
+            else None
         )
-        if "text_projection.weight" in state_dict:
-            self.text_projection = bf16_tensor(
-                state_dict["text_projection.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-            )
-        else:
-            self.text_projection = None
+
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        rename_substate(state, "text_model.embeddings", "embeddings")
+        rename_substate(state, "text_model.encoder", "encoder")
+
+        if "text_model.final_layer_norm.weight" in state:
+            state["final_layer_norm"] = state.pop("text_model.final_layer_norm.weight")
+        if "text_model.final_layer_norm.bias" in state:
+            state["final_layer_norm_bias"] = state.pop("text_model.final_layer_norm.bias")
+        if "text_projection.weight" in state:
+            state["text_projection"] = state.pop("text_projection.weight")
 
     def forward(
         self,
         prompt_tokenized: ttnn.Tensor,
         mesh_device: ttnn.Device,
         *,
-        with_projection: bool | None = None,
         return_normalized_state: bool = False,
     ) -> tuple[ttnn.Tensor, ...]:
-        if with_projection is None:
-            with_projection = self.text_projection is not None
-
         hidden_states = self.embeddings(prompt_tokenized, mesh_device)
 
         causal_attention_mask = create_4d_causal_attention_mask(
@@ -142,8 +140,8 @@ def forward(
         final_hidden_layer = encoder_output[-1]  # final hidden layer
         normalized_final_state = ttnn.layer_norm(  # final layer norm
             final_hidden_layer,
-            weight=self.final_layer_norm,
-            bias=self.final_layer_norm_bias,
+            weight=self.final_layer_norm.data,
+            bias=self.final_layer_norm_bias.data,
             epsilon=self.config.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -160,11 +158,8 @@ def forward(
             ccl_manager=self.ccl_manager,
         )
 
-        # apply text projection if specified
-        if with_projection:
-            if self.text_projection is None:
-                raise ValueError("projection weights are not loaded")
-            text_projection_transposed = ttnn.transpose(self.text_projection, -2, -1)
+        if self.text_projection is not None:
+            text_projection_transposed = ttnn.transpose(self.text_projection.data, -2, -1)
             pooled_output = ttnn.matmul(
                 pooled_output, text_projection_transposed, compute_kernel_config=self.compute_kernel_config
             )
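
The exact contract of the rename_substate helper imported from ...utils.substate is not shown in this diff; from the call sites above it appears to be an in-place prefix rename over a torch state dict. A minimal sketch of that assumed behavior, for illustration only (not the project's implementation):

import torch

def rename_substate(state: dict[str, torch.Tensor], src: str, dst: str) -> None:
    # Move every entry stored under the `src` prefix to the same suffix under `dst`, mutating in place.
    for key in list(state):
        if key == src or key.startswith(src + "."):
            state[dst + key[len(src):]] = state.pop(key)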
@@ -253,14 +248,9 @@ def __init__(
             fp32_dest_acc_en=True,
             packer_l1_acc=True,
         )
-        self.layers = [
+        self.layers = ModuleList(
             CLIPEncoderLayer(config, mesh_device, ccl_manager, parallel_config) for _ in range(config.num_hidden_layers)
-        ]
-
-    def load_torch_state_dict(self, state_dict):
-        layer_states = indexed_substates(state_dict, "layers")
-        for layer, layer_state in zip(self.layers, layer_states):
-            layer.load_torch_state_dict(layer_state)
+        )
 
     def forward(
         self,
@@ -298,8 +288,6 @@ def __init__(
             fp32_dest_acc_en=True,
             packer_l1_acc=True,
         )
-        self.layer_norm1 = None
-        self.layer_norm2 = None
         self.layer_norm_eps = config.layer_norm_eps
         self.self_attn = CLIPAttention(config, mesh_device, ccl_manager, parallel_config)
         self.parallel_config = parallel_config
@@ -321,31 +309,23 @@ def __init__(
         )
         self.ccl_manager = ccl_manager
 
-    def load_torch_state_dict(self, state_dict):
-        self.layer_norm1 = bf16_tensor(
-            state_dict["layer_norm1.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm1_bias = bf16_tensor(
-            state_dict["layer_norm1.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm2 = bf16_tensor(
-            state_dict["layer_norm2.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm2_bias = bf16_tensor(
-            state_dict["layer_norm2.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
+        self.layer_norm1 = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm1_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm2 = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm2_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)
 
-        self.self_attn.load_torch_state_dict(substate(state_dict, "self_attn"))
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        if "layer_norm1.weight" in state:
+            state["layer_norm1"] = state.pop("layer_norm1.weight")
+        if "layer_norm1.bias" in state:
+            state["layer_norm1_bias"] = state.pop("layer_norm1.bias")
+        if "layer_norm2.weight" in state:
+            state["layer_norm2"] = state.pop("layer_norm2.weight")
+        if "layer_norm2.bias" in state:
+            state["layer_norm2_bias"] = state.pop("layer_norm2.bias")
 
-        # remap MLP keys from fc1/fc2 to ff1/ff2 format
-        mlp_state = substate(state_dict, "mlp")
-        remapped_mlp_state = {
-            "ff1.weight": mlp_state["fc1.weight"],
-            "ff1.bias": mlp_state["fc1.bias"],
-            "ff2.weight": mlp_state["fc2.weight"],
-            "ff2.bias": mlp_state["fc2.bias"],
-        }
-        self.mlp.load_torch_state_dict(remapped_mlp_state)
+        rename_substate(state, "mlp.fc1", "mlp.ff1")
+        rename_substate(state, "mlp.fc2", "mlp.ff2")
 
     def forward(
         self,
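
Taken together with CLIPAttention._prepare_torch_state further down (which renames "out_proj" to "o_proj"), and assuming the shared Module machinery delegates sub-prefixed keys to child modules, the net key remapping for one Hugging Face CLIP text-encoder layer comes out roughly as below. The dict name is illustrative and lists only the keys touched in this diff.

# Hugging Face key            ->  key expected by this module
LAYER_KEY_RENAMES = {
    "layer_norm1.weight": "layer_norm1",
    "layer_norm1.bias": "layer_norm1_bias",
    "layer_norm2.weight": "layer_norm2",
    "layer_norm2.bias": "layer_norm2_bias",
    "mlp.fc1.weight": "mlp.ff1.weight",
    "mlp.fc1.bias": "mlp.ff1.bias",
    "mlp.fc2.weight": "mlp.ff2.weight",
    "mlp.fc2.bias": "mlp.ff2.bias",
    "self_attn.out_proj.weight": "self_attn.o_proj.weight",
    "self_attn.out_proj.bias": "self_attn.o_proj.bias",
}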
@@ -357,8 +337,8 @@ def forward(
         residual = hidden_states
         hidden_states = ttnn.layer_norm(
             hidden_states,
-            weight=self.layer_norm1,
-            bias=self.layer_norm1_bias,
+            weight=self.layer_norm1.data,
+            bias=self.layer_norm1_bias.data,
             epsilon=self.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -368,8 +348,8 @@ def forward(
         residual = hidden_states
         hidden_states = ttnn.layer_norm(
             hidden_states,
-            weight=self.layer_norm2,
-            bias=self.layer_norm2_bias,
+            weight=self.layer_norm2.data,
+            bias=self.layer_norm2_bias.data,
             epsilon=self.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -467,11 +447,8 @@ def __init__(
         self.v_proj = Linear(in_features=self.embed_dim, out_features=self.embed_dim, mesh_device=self.mesh_device)
         self.o_proj = Linear(in_features=self.embed_dim, out_features=self.embed_dim, mesh_device=self.mesh_device)
 
-    def load_torch_state_dict(self, state_dict):
-        self.q_proj.load_torch_state_dict(substate(state_dict, "q_proj"))
-        self.k_proj.load_torch_state_dict(substate(state_dict, "k_proj"))
-        self.v_proj.load_torch_state_dict(substate(state_dict, "v_proj"))
-        self.o_proj.load_torch_state_dict(substate(state_dict, "out_proj"))
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        rename_substate(state, "out_proj", "o_proj")
 
     def forward(self, hidden_states, causal_attention_mask):
         batch_size, seq_length, _ = hidden_states.shape
@@ -570,32 +547,22 @@ def __init__(self, config, mesh_device: ttnn.Device) -> None:
         self.config = config
         self.mesh_device = mesh_device
 
-        self.token_embedding = None
-        self.position_embedding = None
-
-    def load_torch_state_dict(self, state_dict):
-        self.token_embedding = bf16_tensor(
-            state_dict["token_embedding.weight"], device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT
+        self.token_embedding = Parameter(
+            total_shape=[config.vocab_size, config.embed_dim],
+            device=mesh_device,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
         )
-        self.position_embedding = bf16_tensor(
-            state_dict["position_embedding.weight"], device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT
+        self.position_embedding = Parameter(
+            total_shape=[config.max_prompt_length, config.embed_dim],
+            device=mesh_device,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
        )
 
-    # TODO: Move to parameters to reuse module functionality
-    def to_cached_state_dict(self, path_prefix, path_suffix=".tensorbin"):
-        cache_dict = {}
-        token_embedding_weights_path = path_prefix + "token_embedding_weights" + path_suffix
-        position_embedding_weights_path = path_prefix + "position_embedding_weights" + path_suffix
-        ttnn.dump_tensor(token_embedding_weights_path, self.token_embedding)
-        ttnn.dump_tensor(position_embedding_weights_path, self.position_embedding)
-        cache_dict["token_embedding"] = token_embedding_weights_path
-        cache_dict["position_embedding"] = position_embedding_weights_path
-
-        return cache_dict
-
-    def from_cached_state_dict(self, cache_dict):
-        self.token_embedding = ttnn.load_tensor(cache_dict["token_embedding"], device=self.mesh_device)
-        self.position_embedding = ttnn.load_tensor(cache_dict["position_embedding"], device=self.mesh_device)
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        if "token_embedding.weight" in state:
+            state["token_embedding"] = state.pop("token_embedding.weight")
+        if "position_embedding.weight" in state:
+            state["position_embedding"] = state.pop("position_embedding.weight")
 
     def forward(self, prompt: ttnn.Tensor, device: ttnn.Device) -> ttnn.Tensor:
         seq_len = prompt.shape[-1]
@@ -604,11 +571,11 @@ def forward(self, prompt: ttnn.Tensor, device: ttnn.Device) -> ttnn.Tensor:
             prompt = prompt[:, : self.config.max_prompt_length]
             seq_len = self.config.max_prompt_length
 
-        input_embeddings = ttnn.embedding(prompt, self.token_embedding, layout=ttnn.TILE_LAYOUT)
+        input_embeddings = ttnn.embedding(prompt, self.token_embedding.data, layout=ttnn.TILE_LAYOUT)
 
         position_ids = torch.arange(seq_len).expand((1, -1))  # shape: (1, seq_len)
         position_ids_ttnn = ttnn.from_torch(position_ids, dtype=ttnn.uint32, layout=ttnn.TILE_LAYOUT, device=device)
-        position_embeddings = ttnn.embedding(position_ids_ttnn, self.position_embedding, layout=ttnn.TILE_LAYOUT)
+        position_embeddings = ttnn.embedding(position_ids_ttnn, self.position_embedding.data, layout=ttnn.TILE_LAYOUT)
 
         return input_embeddings + position_embeddings
 
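
As a sanity reference for the embedding forward above, the equivalent computation in plain torch is sketched below. The sizes are the usual CLIP ViT-L/14 text-encoder defaults (vocab_size=49408, embed_dim=768, max_prompt_length=77), assumed here for illustration rather than taken from this diff.

import torch

# Plain-torch reference: token embeddings plus learned position embeddings.
token_embedding = torch.nn.Embedding(49408, 768)
position_embedding = torch.nn.Embedding(77, 768)

prompt = torch.randint(0, 49408, (1, 77))                    # tokenized prompt, shape (1, seq_len)
position_ids = torch.arange(77).expand((1, -1))              # shape (1, seq_len)
reference = token_embedding(prompt) + position_embedding(position_ids)  # shape (1, 77, 768)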