Commit db36c85

migrate all models to make full use of the Module base class

1 parent ffc9776

36 files changed (+1668, -2169 lines)
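
The pattern behind the migration, sketched minimally: each model class now declares Parameter placeholders in __init__ and implements a _prepare_torch_state hook that renames torch checkpoint keys in place, while a shared Module.load_torch_state_dict entry point (defined in layers/module.py, which is not part of the excerpt below) is assumed to call that hook and then hand the prepared state to child modules and Parameters.

# Hedged sketch of the base-class contract implied by the diffs below; the real
# Module lives in models/experimental/tt_dit/layers/module.py and is not shown
# in this commit, so everything beyond the two method names is an assumption.
import torch

class Module:
    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
        """Hook for subclasses: rename or move checkpoint keys in place."""

    def load_torch_state_dict(self, state: dict[str, torch.Tensor]) -> None:
        state = dict(state)  # work on a copy so the caller's dict is untouched
        self._prepare_torch_state(state)
        # The base class is then assumed to route "child.<key>" entries to
        # registered child Modules and to fill declared Parameters from the
        # remaining top-level keys.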

models/experimental/tt_dit/encoders/clip/encoder_pair.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def _load_encoder(self, checkpoint: str, *, use_torch: bool) -> CLIPEncoder | CL
             eos_token_id=2,  # default EOS token ID for CLIP
         )

-        model.load_state_dict(torch_model.state_dict())
+        model.load_torch_state_dict(torch_model.state_dict())

         return model

models/experimental/tt_dit/encoders/clip/model_clip.py

Lines changed: 67 additions & 100 deletions
@@ -6,15 +6,14 @@

 import torch
 import ttnn
+from ttnn.distributed.distributed import ConcatMeshToTensor

-from ...utils.tensor import bf16_tensor
-from ...utils.substate import substate, indexed_substates
-from ...parallel.manager import CCLManager
-from ...parallel.config import EncoderParallelConfig
-from ...layers.feedforward import ParallelFeedForward, FeedForward
+from ...layers.feedforward import FeedForward, ParallelFeedForward
 from ...layers.linear import ColParallelLinear, Linear
-from ttnn.distributed.distributed import ConcatMeshToTensor
-from ...layers.module import Module
+from ...layers.module import Module, ModuleList, Parameter
+from ...parallel.config import EncoderParallelConfig
+from ...parallel.manager import CCLManager
+from ...utils.substate import rename_substate


 class CLIPConfig:
@@ -56,6 +55,7 @@ def __init__(
         layer_norm_eps: float = 1e-05,
         attention_dropout: float = 0.0,
         hidden_act: str = "quick_gelu",
+        projection_dim: int | None = None,
     ):
         self.vocab_size = vocab_size
         self.embed_dim = embed_dim
@@ -65,6 +65,7 @@ def __init__(
         self.max_prompt_length = max_prompt_length
         self.layer_norm_eps = layer_norm_eps
         self.attention_dropout = attention_dropout
+        self.projection_dim = projection_dim
         if hidden_act == "gelu":
             self.hidden_act = "decomposed_gelu"
         else:
@@ -96,36 +97,33 @@ def __init__(
         self.embeddings = TextEmbeddings(config, mesh_device)
         self.eos_token_id = eos_token_id
         self.encoder = CLIPStack(config, self.mesh_device, self.ccl_manager, self.parallel_config)
-        self.text_projection = None

-    def load_torch_state_dict(self, state_dict):
-        self.embeddings.load_torch_state_dict(substate(state_dict, "text_model.embeddings"))
-        self.encoder.load_torch_state_dict(substate(state_dict, "text_model.encoder"))
-
-        self.final_layer_norm = bf16_tensor(
-            state_dict["text_model.final_layer_norm.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.final_layer_norm_bias = bf16_tensor(
-            state_dict["text_model.final_layer_norm.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
+        self.final_layer_norm = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.final_layer_norm_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.text_projection = (
+            Parameter(total_shape=[config.embed_dim, config.projection_dim], device=mesh_device)
+            if config.projection_dim is not None
+            else None
         )
-        if "text_projection.weight" in state_dict:
-            self.text_projection = bf16_tensor(
-                state_dict["text_projection.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-            )
-        else:
-            self.text_projection = None
+
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        rename_substate(state, "text_model.embeddings", "embeddings")
+        rename_substate(state, "text_model.encoder", "encoder")
+
+        if "text_model.final_layer_norm.weight" in state:
+            state["final_layer_norm"] = state.pop("text_model.final_layer_norm.weight")
+        if "text_model.final_layer_norm.bias" in state:
+            state["final_layer_norm_bias"] = state.pop("text_model.final_layer_norm.bias")
+        if "text_projection.weight" in state:
+            state["text_projection"] = state.pop("text_projection.weight")

     def forward(
         self,
         prompt_tokenized: ttnn.Tensor,
         mesh_device: ttnn.Device,
         *,
-        with_projection: bool | None = None,
         return_normalized_state: bool = False,
     ) -> tuple[ttnn.Tensor, ...]:
-        if with_projection is None:
-            with_projection = self.text_projection is not None
-
         hidden_states = self.embeddings(prompt_tokenized, mesh_device)

         causal_attention_mask = create_4d_causal_attention_mask(
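
The _prepare_torch_state hook above leans on rename_substate from ...utils.substate, whose implementation is not part of this diff. Inferred from the call sites, it is an in-place prefix rename, roughly as in this sketch (an assumption, not the repository's actual helper):

import torch

def rename_substate(state: dict[str, torch.Tensor], old: str, new: str) -> None:
    # Move every key equal to `old`, or underneath the `old.` prefix, to `new`.
    for key in [k for k in state if k == old or k.startswith(old + ".")]:
        suffix = key[len(old):]  # "" for an exact match, else ".rest.of.key"
        state[new + suffix] = state.pop(key)

Under that reading, rename_substate(state, "text_model.encoder", "encoder") turns "text_model.encoder.layers.0.self_attn.q_proj.weight" into "encoder.layers.0.self_attn.q_proj.weight", which the base class can then dispatch to self.encoder.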
@@ -141,8 +139,8 @@ def forward(
         final_hidden_layer = encoder_output[-1]  # final hidden layer
         normalized_final_state = ttnn.layer_norm(  # final layer norm
             final_hidden_layer,
-            weight=self.final_layer_norm,
-            bias=self.final_layer_norm_bias,
+            weight=self.final_layer_norm.data,
+            bias=self.final_layer_norm_bias.data,
             epsilon=self.config.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -159,11 +157,8 @@ def forward(
             ccl_manager=self.ccl_manager,
         )

-        # apply text projection if specified
-        if with_projection:
-            if self.text_projection is None:
-                raise ValueError("projection weights are not loaded")
-            text_projection_transposed = ttnn.transpose(self.text_projection, -2, -1)
+        if self.text_projection is not None:
+            text_projection_transposed = ttnn.transpose(self.text_projection.data, -2, -1)
             pooled_output = ttnn.matmul(
                 pooled_output, text_projection_transposed, compute_kernel_config=self.compute_kernel_config
             )
@@ -252,14 +247,9 @@ def __init__(
             fp32_dest_acc_en=True,
             packer_l1_acc=True,
         )
-        self.layers = [
+        self.layers = ModuleList(
             CLIPEncoderLayer(config, mesh_device, ccl_manager, parallel_config) for _ in range(config.num_hidden_layers)
-        ]
-
-    def load_torch_state_dict(self, state_dict):
-        layer_states = indexed_substates(state_dict, "layers")
-        for layer, layer_state in zip(self.layers, layer_states):
-            layer.load_torch_state_dict(layer_state)
+        )

     def forward(
         self,
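
Switching self.layers to a ModuleList removes the hand-written indexed_substates loop. How ModuleList routes per-layer keys is not shown in this commit; a sketch of the assumed behavior, equivalent to the deleted code, would be:

def load_indexed_children(layers, state: dict) -> None:
    # Assumed routing: an entry such as "3.self_attn.q_proj.weight" (after the
    # parent has stripped the "layers." prefix) goes to layers[3] as
    # "self_attn.q_proj.weight".
    for i, layer in enumerate(layers):
        prefix = f"{i}."
        layer.load_torch_state_dict(
            {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
        )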
@@ -297,8 +287,6 @@ def __init__(
             fp32_dest_acc_en=True,
             packer_l1_acc=True,
         )
-        self.layer_norm1 = None
-        self.layer_norm2 = None
         self.layer_norm_eps = config.layer_norm_eps
         self.self_attn = CLIPAttention(config, mesh_device, ccl_manager, parallel_config)
         self.parallel_config = parallel_config
@@ -320,31 +308,23 @@ def __init__(
         )
         self.ccl_manager = ccl_manager

-    def load_torch_state_dict(self, state_dict):
-        self.layer_norm1 = bf16_tensor(
-            state_dict["layer_norm1.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm1_bias = bf16_tensor(
-            state_dict["layer_norm1.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm2 = bf16_tensor(
-            state_dict["layer_norm2.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm2_bias = bf16_tensor(
-            state_dict["layer_norm2.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
+        self.layer_norm1 = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm1_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm2 = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm2_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)

-        self.self_attn.load_torch_state_dict(substate(state_dict, "self_attn"))
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        if "layer_norm1.weight" in state:
+            state["layer_norm1"] = state.pop("layer_norm1.weight")
+        if "layer_norm1.bias" in state:
+            state["layer_norm1_bias"] = state.pop("layer_norm1.bias")
+        if "layer_norm2.weight" in state:
+            state["layer_norm2"] = state.pop("layer_norm2.weight")
+        if "layer_norm2.bias" in state:
+            state["layer_norm2_bias"] = state.pop("layer_norm2.bias")

-        # remap MLP keys from fc1/fc2 to ff1/ff2 format
-        mlp_state = substate(state_dict, "mlp")
-        remapped_mlp_state = {
-            "ff1.weight": mlp_state["fc1.weight"],
-            "ff1.bias": mlp_state["fc1.bias"],
-            "ff2.weight": mlp_state["fc2.weight"],
-            "ff2.bias": mlp_state["fc2.bias"],
-        }
-        self.mlp.load_torch_state_dict(remapped_mlp_state)
+        rename_substate(state, "mlp.fc1", "mlp.ff1")
+        rename_substate(state, "mlp.fc2", "mlp.ff2")

     def forward(
         self,
@@ -356,8 +336,8 @@ def forward(
         residual = hidden_states
         hidden_states = ttnn.layer_norm(
             hidden_states,
-            weight=self.layer_norm1,
-            bias=self.layer_norm1_bias,
+            weight=self.layer_norm1.data,
+            bias=self.layer_norm1_bias.data,
             epsilon=self.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -367,8 +347,8 @@
         residual = hidden_states
         hidden_states = ttnn.layer_norm(
             hidden_states,
-            weight=self.layer_norm2,
-            bias=self.layer_norm2_bias,
+            weight=self.layer_norm2.data,
+            bias=self.layer_norm2_bias.data,
             epsilon=self.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -466,11 +446,8 @@ def __init__(
         self.v_proj = Linear(in_features=self.embed_dim, out_features=self.embed_dim, mesh_device=self.mesh_device)
         self.o_proj = Linear(in_features=self.embed_dim, out_features=self.embed_dim, mesh_device=self.mesh_device)

-    def load_torch_state_dict(self, state_dict):
-        self.q_proj.load_torch_state_dict(substate(state_dict, "q_proj"))
-        self.k_proj.load_torch_state_dict(substate(state_dict, "k_proj"))
-        self.v_proj.load_torch_state_dict(substate(state_dict, "v_proj"))
-        self.o_proj.load_torch_state_dict(substate(state_dict, "out_proj"))
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        rename_substate(state, "out_proj", "o_proj")

     def forward(self, hidden_states, causal_attention_mask):
         batch_size, seq_length, _ = hidden_states.shape
@@ -569,32 +546,22 @@ def __init__(self, config, mesh_device: ttnn.Device) -> None:
         self.config = config
         self.mesh_device = mesh_device

-        self.token_embedding = None
-        self.position_embedding = None
-
-    def load_torch_state_dict(self, state_dict):
-        self.token_embedding = bf16_tensor(
-            state_dict["token_embedding.weight"], device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT
+        self.token_embedding = Parameter(
+            total_shape=[config.vocab_size, config.embed_dim],
+            device=mesh_device,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
         )
-        self.position_embedding = bf16_tensor(
-            state_dict["position_embedding.weight"], device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT
+        self.position_embedding = Parameter(
+            total_shape=[config.max_prompt_length, config.embed_dim],
+            device=mesh_device,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
         )

-    # TODO: Move to parameters to reuse module functionality
-    def to_cached_state_dict(self, path_prefix, path_suffix=".tensorbin"):
-        cache_dict = {}
-        token_embedding_weights_path = path_prefix + "token_embedding_weights" + path_suffix
-        position_embedding_weights_path = path_prefix + "position_embedding_weights" + path_suffix
-        ttnn.dump_tensor(token_embedding_weights_path, self.token_embedding)
-        ttnn.dump_tensor(position_embedding_weights_path, self.position_embedding)
-        cache_dict["token_embedding"] = token_embedding_weights_path
-        cache_dict["position_embedding"] = position_embedding_weights_path
-
-        return cache_dict
-
-    def from_cached_state_dict(self, cache_dict):
-        self.token_embedding = ttnn.load_tensor(cache_dict["token_embedding"], device=self.mesh_device)
-        self.position_embedding = ttnn.load_tensor(cache_dict["position_embedding"], device=self.mesh_device)
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        if "token_embedding.weight" in state:
+            state["token_embedding"] = state.pop("token_embedding.weight")
+        if "position_embedding.weight" in state:
+            state["position_embedding"] = state.pop("position_embedding.weight")

     def forward(self, prompt: ttnn.Tensor, device: ttnn.Device) -> ttnn.Tensor:
         seq_len = prompt.shape[-1]
@@ -603,11 +570,11 @@ def forward(self, prompt: ttnn.Tensor, device: ttnn.Device) -> ttnn.Tensor:
             prompt = prompt[:, : self.config.max_prompt_length]
             seq_len = self.config.max_prompt_length

-        input_embeddings = ttnn.embedding(prompt, self.token_embedding, layout=ttnn.TILE_LAYOUT)
+        input_embeddings = ttnn.embedding(prompt, self.token_embedding.data, layout=ttnn.TILE_LAYOUT)

         position_ids = torch.arange(seq_len).expand((1, -1))  # shape: (1, seq_len)
         position_ids_ttnn = ttnn.from_torch(position_ids, dtype=ttnn.uint32, layout=ttnn.TILE_LAYOUT, device=device)
-        position_embeddings = ttnn.embedding(position_ids_ttnn, self.position_embedding, layout=ttnn.TILE_LAYOUT)
+        position_embeddings = ttnn.embedding(position_ids_ttnn, self.position_embedding.data, layout=ttnn.TILE_LAYOUT)

         return input_embeddings + position_embeddings
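
End-to-end, the call-site change in encoder_pair.py plus the new projection_dim config field imply a loading flow along these lines. Only load_torch_state_dict, projection_dim, and eos_token_id=2 come from this commit; the HF checkpoint name and the remaining constructor arguments are illustrative assumptions.

import torch
from transformers import CLIPTextModelWithProjection
from models.experimental.tt_dit.encoders.clip.model_clip import CLIPConfig, CLIPEncoder

def load_clip_encoder(mesh_device, ccl_manager, parallel_config):
    # Reference torch weights; the checkpoint name is an assumption for this sketch.
    torch_model = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    config = CLIPConfig(projection_dim=torch_model.config.projection_dim)
    model = CLIPEncoder(config, mesh_device, ccl_manager, parallel_config, eos_token_id=2)
    # Single entry point from the Module base class: key renames happen in the
    # per-module _prepare_torch_state hooks shown in the diff above.
    model.load_torch_state_dict(torch_model.state_dict())
    return model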