 
 import torch
 import ttnn
+from ttnn.distributed.distributed import ConcatMeshToTensor
 
-from ...utils.tensor import bf16_tensor
-from ...utils.substate import substate, indexed_substates
-from ...parallel.manager import CCLManager
-from ...parallel.config import EncoderParallelConfig
-from ...layers.feedforward import ParallelFeedForward, FeedForward
+from ...layers.feedforward import FeedForward, ParallelFeedForward
 from ...layers.linear import ColParallelLinear, Linear
-from ttnn.distributed.distributed import ConcatMeshToTensor
-from ...layers.module import Module
+from ...layers.module import Module, ModuleList, Parameter
+from ...parallel.config import EncoderParallelConfig
+from ...parallel.manager import CCLManager
+from ...utils.substate import rename_substate
 
 
 class CLIPConfig:
@@ -56,6 +55,7 @@ def __init__(
         layer_norm_eps: float = 1e-05,
         attention_dropout: float = 0.0,
         hidden_act: str = "quick_gelu",
+        projection_dim: int | None = None,
     ):
         self.vocab_size = vocab_size
         self.embed_dim = embed_dim
@@ -65,6 +65,7 @@ def __init__(
         self.max_prompt_length = max_prompt_length
         self.layer_norm_eps = layer_norm_eps
         self.attention_dropout = attention_dropout
+        self.projection_dim = projection_dim
         if hidden_act == "gelu":
             self.hidden_act = "decomposed_gelu"
         else:
@@ -96,36 +97,33 @@ def __init__(
         self.embeddings = TextEmbeddings(config, mesh_device)
         self.eos_token_id = eos_token_id
         self.encoder = CLIPStack(config, self.mesh_device, self.ccl_manager, self.parallel_config)
-        self.text_projection = None
 
-    def load_torch_state_dict(self, state_dict):
-        self.embeddings.load_torch_state_dict(substate(state_dict, "text_model.embeddings"))
-        self.encoder.load_torch_state_dict(substate(state_dict, "text_model.encoder"))
-
-        self.final_layer_norm = bf16_tensor(
-            state_dict["text_model.final_layer_norm.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.final_layer_norm_bias = bf16_tensor(
-            state_dict["text_model.final_layer_norm.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
+        self.final_layer_norm = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.final_layer_norm_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.text_projection = (
+            Parameter(total_shape=[config.embed_dim, config.projection_dim], device=mesh_device)
+            if config.projection_dim is not None
+            else None
         )
-        if "text_projection.weight" in state_dict:
-            self.text_projection = bf16_tensor(
-                state_dict["text_projection.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-            )
-        else:
-            self.text_projection = None
+
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        rename_substate(state, "text_model.embeddings", "embeddings")
+        rename_substate(state, "text_model.encoder", "encoder")
+
+        if "text_model.final_layer_norm.weight" in state:
+            state["final_layer_norm"] = state.pop("text_model.final_layer_norm.weight")
+        if "text_model.final_layer_norm.bias" in state:
+            state["final_layer_norm_bias"] = state.pop("text_model.final_layer_norm.bias")
+        if "text_projection.weight" in state:
+            state["text_projection"] = state.pop("text_projection.weight")
 
     def forward(
         self,
         prompt_tokenized: ttnn.Tensor,
         mesh_device: ttnn.Device,
         *,
-        with_projection: bool | None = None,
         return_normalized_state: bool = False,
     ) -> tuple[ttnn.Tensor, ...]:
-        if with_projection is None:
-            with_projection = self.text_projection is not None
-
         hidden_states = self.embeddings(prompt_tokenized, mesh_device)
 
         causal_attention_mask = create_4d_causal_attention_mask(
@@ -141,8 +139,8 @@ def forward(
         final_hidden_layer = encoder_output[-1]  # final hidden layer
         normalized_final_state = ttnn.layer_norm(  # final layer norm
             final_hidden_layer,
-            weight=self.final_layer_norm,
-            bias=self.final_layer_norm_bias,
+            weight=self.final_layer_norm.data,
+            bias=self.final_layer_norm_bias.data,
             epsilon=self.config.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -159,11 +157,8 @@ def forward(
             ccl_manager=self.ccl_manager,
         )
 
-        # apply text projection if specified
-        if with_projection:
-            if self.text_projection is None:
-                raise ValueError("projection weights are not loaded")
-            text_projection_transposed = ttnn.transpose(self.text_projection, -2, -1)
+        if self.text_projection is not None:
+            text_projection_transposed = ttnn.transpose(self.text_projection.data, -2, -1)
             pooled_output = ttnn.matmul(
                 pooled_output, text_projection_transposed, compute_kernel_config=self.compute_kernel_config
             )
@@ -252,14 +247,9 @@ def __init__(
             fp32_dest_acc_en=True,
             packer_l1_acc=True,
         )
-        self.layers = [
+        self.layers = ModuleList(
             CLIPEncoderLayer(config, mesh_device, ccl_manager, parallel_config) for _ in range(config.num_hidden_layers)
-        ]
-
-    def load_torch_state_dict(self, state_dict):
-        layer_states = indexed_substates(state_dict, "layers")
-        for layer, layer_state in zip(self.layers, layer_states):
-            layer.load_torch_state_dict(layer_state)
+        )
 
     def forward(
         self,
@@ -297,8 +287,6 @@ def __init__(
             fp32_dest_acc_en=True,
             packer_l1_acc=True,
         )
-        self.layer_norm1 = None
-        self.layer_norm2 = None
         self.layer_norm_eps = config.layer_norm_eps
         self.self_attn = CLIPAttention(config, mesh_device, ccl_manager, parallel_config)
         self.parallel_config = parallel_config
@@ -320,31 +308,23 @@ def __init__(
         )
         self.ccl_manager = ccl_manager
 
-    def load_torch_state_dict(self, state_dict):
-        self.layer_norm1 = bf16_tensor(
-            state_dict["layer_norm1.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm1_bias = bf16_tensor(
-            state_dict["layer_norm1.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm2 = bf16_tensor(
-            state_dict["layer_norm2.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm2_bias = bf16_tensor(
-            state_dict["layer_norm2.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
+        self.layer_norm1 = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm1_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm2 = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm2_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)
 
-        self.self_attn.load_torch_state_dict(substate(state_dict, "self_attn"))
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        if "layer_norm1.weight" in state:
+            state["layer_norm1"] = state.pop("layer_norm1.weight")
+        if "layer_norm1.bias" in state:
+            state["layer_norm1_bias"] = state.pop("layer_norm1.bias")
+        if "layer_norm2.weight" in state:
+            state["layer_norm2"] = state.pop("layer_norm2.weight")
+        if "layer_norm2.bias" in state:
+            state["layer_norm2_bias"] = state.pop("layer_norm2.bias")
 
-        # remap MLP keys from fc1/fc2 to ff1/ff2 format
-        mlp_state = substate(state_dict, "mlp")
-        remapped_mlp_state = {
-            "ff1.weight": mlp_state["fc1.weight"],
-            "ff1.bias": mlp_state["fc1.bias"],
-            "ff2.weight": mlp_state["fc2.weight"],
-            "ff2.bias": mlp_state["fc2.bias"],
-        }
-        self.mlp.load_torch_state_dict(remapped_mlp_state)
+        rename_substate(state, "mlp.fc1", "mlp.ff1")
+        rename_substate(state, "mlp.fc2", "mlp.ff2")
 
     def forward(
         self,
@@ -356,8 +336,8 @@ def forward(
         residual = hidden_states
         hidden_states = ttnn.layer_norm(
             hidden_states,
-            weight=self.layer_norm1,
-            bias=self.layer_norm1_bias,
+            weight=self.layer_norm1.data,
+            bias=self.layer_norm1_bias.data,
             epsilon=self.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -367,8 +347,8 @@ def forward(
         residual = hidden_states
         hidden_states = ttnn.layer_norm(
             hidden_states,
-            weight=self.layer_norm2,
-            bias=self.layer_norm2_bias,
+            weight=self.layer_norm2.data,
+            bias=self.layer_norm2_bias.data,
             epsilon=self.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -466,11 +446,8 @@ def __init__(
         self.v_proj = Linear(in_features=self.embed_dim, out_features=self.embed_dim, mesh_device=self.mesh_device)
         self.o_proj = Linear(in_features=self.embed_dim, out_features=self.embed_dim, mesh_device=self.mesh_device)
 
-    def load_torch_state_dict(self, state_dict):
-        self.q_proj.load_torch_state_dict(substate(state_dict, "q_proj"))
-        self.k_proj.load_torch_state_dict(substate(state_dict, "k_proj"))
-        self.v_proj.load_torch_state_dict(substate(state_dict, "v_proj"))
-        self.o_proj.load_torch_state_dict(substate(state_dict, "out_proj"))
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        rename_substate(state, "out_proj", "o_proj")
 
     def forward(self, hidden_states, causal_attention_mask):
         batch_size, seq_length, _ = hidden_states.shape
@@ -569,32 +546,22 @@ def __init__(self, config, mesh_device: ttnn.Device) -> None:
         self.config = config
         self.mesh_device = mesh_device
 
-        self.token_embedding = None
-        self.position_embedding = None
-
-    def load_torch_state_dict(self, state_dict):
-        self.token_embedding = bf16_tensor(
-            state_dict["token_embedding.weight"], device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT
+        self.token_embedding = Parameter(
+            total_shape=[config.vocab_size, config.embed_dim],
+            device=mesh_device,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
         )
-        self.position_embedding = bf16_tensor(
-            state_dict["position_embedding.weight"], device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT
+        self.position_embedding = Parameter(
+            total_shape=[config.max_prompt_length, config.embed_dim],
+            device=mesh_device,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
         )
 
-    # TODO: Move to parameters to reuse module functionality
-    def to_cached_state_dict(self, path_prefix, path_suffix=".tensorbin"):
-        cache_dict = {}
-        token_embedding_weights_path = path_prefix + "token_embedding_weights" + path_suffix
-        position_embedding_weights_path = path_prefix + "position_embedding_weights" + path_suffix
-        ttnn.dump_tensor(token_embedding_weights_path, self.token_embedding)
-        ttnn.dump_tensor(position_embedding_weights_path, self.position_embedding)
-        cache_dict["token_embedding"] = token_embedding_weights_path
-        cache_dict["position_embedding"] = position_embedding_weights_path
-
-        return cache_dict
-
-    def from_cached_state_dict(self, cache_dict):
-        self.token_embedding = ttnn.load_tensor(cache_dict["token_embedding"], device=self.mesh_device)
-        self.position_embedding = ttnn.load_tensor(cache_dict["position_embedding"], device=self.mesh_device)
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        if "token_embedding.weight" in state:
+            state["token_embedding"] = state.pop("token_embedding.weight")
+        if "position_embedding.weight" in state:
+            state["position_embedding"] = state.pop("position_embedding.weight")
 
     def forward(self, prompt: ttnn.Tensor, device: ttnn.Device) -> ttnn.Tensor:
         seq_len = prompt.shape[-1]
@@ -603,11 +570,11 @@ def forward(self, prompt: ttnn.Tensor, device: ttnn.Device) -> ttnn.Tensor:
             prompt = prompt[:, : self.config.max_prompt_length]
             seq_len = self.config.max_prompt_length
 
-        input_embeddings = ttnn.embedding(prompt, self.token_embedding, layout=ttnn.TILE_LAYOUT)
+        input_embeddings = ttnn.embedding(prompt, self.token_embedding.data, layout=ttnn.TILE_LAYOUT)
 
         position_ids = torch.arange(seq_len).expand((1, -1))  # shape: (1, seq_len)
         position_ids_ttnn = ttnn.from_torch(position_ids, dtype=ttnn.uint32, layout=ttnn.TILE_LAYOUT, device=device)
-        position_embeddings = ttnn.embedding(position_ids_ttnn, self.position_embedding, layout=ttnn.TILE_LAYOUT)
+        position_embeddings = ttnn.embedding(position_ids_ttnn, self.position_embedding.data, layout=ttnn.TILE_LAYOUT)
 
         return input_embeddings + position_embeddings
 
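
Reviewer note: the hunks above replace per-module load_torch_state_dict methods with declarative Parameter objects plus a _prepare_torch_state key-renaming hook, with the shared Module base class assumed to copy each renamed entry into the matching Parameter (read back through .data in forward). Below is a small, self-contained sketch of what that hook does to an incoming HuggingFace-style state dict. It is illustrative only: rename_substate here is a toy stand-in whose prefix-renaming behaviour is inferred from how the diff uses it, and the tensor shapes are arbitrary.

import torch


def rename_substate(state: dict[str, torch.Tensor], old: str, new: str) -> None:
    # Assumed behaviour of utils.substate.rename_substate: move every key under the
    # "old." prefix to live under "new." instead, mutating the dict in place.
    for key in [k for k in state if k.startswith(old + ".")]:
        state[new + key[len(old):]] = state.pop(key)


def prepare_torch_state(state: dict[str, torch.Tensor]) -> None:
    # Mirrors a subset of the top-level _prepare_torch_state added in this diff.
    rename_substate(state, "text_model.embeddings", "embeddings")
    rename_substate(state, "text_model.encoder", "encoder")
    if "text_model.final_layer_norm.weight" in state:
        state["final_layer_norm"] = state.pop("text_model.final_layer_norm.weight")
    if "text_projection.weight" in state:
        state["text_projection"] = state.pop("text_projection.weight")


state = {
    "text_model.embeddings.token_embedding.weight": torch.zeros(8, 4),  # toy shapes
    "text_model.encoder.layers.0.mlp.fc1.weight": torch.zeros(16, 4),
    "text_model.final_layer_norm.weight": torch.zeros(4),
    "text_projection.weight": torch.zeros(4, 4),
}
prepare_torch_state(state)
print(sorted(state))
# ['embeddings.token_embedding.weight', 'encoder.layers.0.mlp.fc1.weight',
#  'final_layer_norm', 'text_projection']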