Commit f8c7d16

migrate all models to make full use of the Module base class
1 parent ef4f2bf commit f8c7d16

37 files changed (+1605, -2113 lines changed)
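Only two of the 37 changed files are reproduced below, but they show the pattern the whole migration follows: each module now declares its weights as Parameter objects in __init__ and overrides a _prepare_torch_state hook that renames checkpoint keys in place, leaving the shared Module.load_torch_state_dict to distribute tensors into submodules and Parameters. The base class itself is not part of this excerpt; the toy sketch below only illustrates that assumed division of labour (the class name, recursion strategy, and attribute handling are inferences from the diff, not the real tt_dit implementation).

class SketchModule:
    """Toy stand-in for the Module base class, for illustration only."""

    def _prepare_torch_state(self, state):
        # Subclasses rename checkpoint keys in place; the default is a no-op.
        pass

    def load_torch_state_dict(self, state):
        state = dict(state)  # never mutate the caller's dict
        self._prepare_torch_state(state)
        # Route "child.<key>" entries to registered child modules.
        for name, child in vars(self).items():
            if isinstance(child, SketchModule):
                prefix = name + "."
                sub = {k[len(prefix):]: state.pop(k) for k in list(state) if k.startswith(prefix)}
                child.load_torch_state_dict(sub)
        # Whatever is left are flat keys that fill this module's own parameters.
        for name, tensor in state.items():
            setattr(self, name, tensor)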

models/tt_dit/encoders/clip/encoder_pair.py
Lines changed: 1 addition & 1 deletion

@@ -74,7 +74,7 @@ def _load_encoder(self, checkpoint: str, *, use_torch: bool) -> CLIPEncoder | CL
             eos_token_id=2,  # default EOS token ID for CLIP
         )

-        model.load_state_dict(torch_model.state_dict())
+        model.load_torch_state_dict(torch_model.state_dict())

         return model
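At this call site only the method name changes; the checkpoint is still a plain torch state dict. A hedged sketch of what that dict looks like before the remapping hooks run (Hugging Face transformers is an assumption about where torch_model comes from, and the checkpoint name is illustrative):

from transformers import CLIPTextModel

torch_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
state = torch_model.state_dict()

# Keys arrive in the HF layout, e.g. "text_model.encoder.layers.0.mlp.fc1.weight";
# the _prepare_torch_state hooks below remap them to the tt_dit attribute names.
print(len(state), sorted(state)[0])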

models/tt_dit/encoders/clip/model_clip.py
Lines changed: 63 additions & 96 deletions

@@ -11,11 +11,10 @@

 from ...layers.feedforward import FeedForward, ParallelFeedForward
 from ...layers.linear import ColParallelLinear, Linear
-from ...layers.module import Module
+from ...layers.module import Module, ModuleList, Parameter
 from ...parallel.config import EncoderParallelConfig
 from ...parallel.manager import CCLManager
-from ...utils.substate import indexed_substates, substate
-from ...utils.tensor import bf16_tensor
+from ...utils.substate import rename_substate


 class CLIPConfig:
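The one new utility imported here, rename_substate, is used throughout the file to move whole key prefixes (for example "mlp.fc1" to "mlp.ff1"). The real helper lives in ...utils.substate and is not shown in this commit; the stand-in below is only a guess at its in-place semantics, inferred from how it is called:

def rename_substate_sketch(state, old_prefix, new_prefix):
    """Move every key equal to old_prefix or under "old_prefix." to new_prefix, in place."""
    for key in list(state):
        if key == old_prefix or key.startswith(old_prefix + "."):
            state[new_prefix + key[len(old_prefix):]] = state.pop(key)


s = {"mlp.fc1.weight": 0, "mlp.fc1.bias": 1, "mlp.fc2.weight": 2}
rename_substate_sketch(s, "mlp.fc1", "mlp.ff1")
print(sorted(s))  # ['mlp.fc2.weight', 'mlp.ff1.bias', 'mlp.ff1.weight']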
@@ -57,6 +56,7 @@ def __init__(
         layer_norm_eps: float = 1e-05,
         attention_dropout: float = 0.0,
         hidden_act: str = "quick_gelu",
+        projection_dim: int | None = None,
     ):
         self.vocab_size = vocab_size
         self.embed_dim = embed_dim
@@ -66,6 +66,7 @@ def __init__(
         self.max_prompt_length = max_prompt_length
         self.layer_norm_eps = layer_norm_eps
         self.attention_dropout = attention_dropout
+        self.projection_dim = projection_dim
         if hidden_act == "gelu":
            self.hidden_act = "decomposed_gelu"
         else:
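projection_dim is now an explicit config field rather than being inferred from whether the checkpoint contains text_projection.weight; None builds the encoder without a projection. A hedged construction sketch (the sizes are typical CLIP ViT-L/14 text-tower values, and the remaining constructor arguments, not visible in this diff, are deliberately elided):

# Illustrative only; other required CLIPConfig arguments are omitted here.
text_cfg = dict(vocab_size=49408, embed_dim=768, max_prompt_length=77)

# CLIPConfig(**text_cfg)                      -> self.projection_dim is None, no text_projection
# CLIPConfig(**text_cfg, projection_dim=768)  -> a [embed_dim, projection_dim] text_projection Parameter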
@@ -97,36 +98,33 @@ def __init__(
         self.embeddings = TextEmbeddings(config, mesh_device)
         self.eos_token_id = eos_token_id
         self.encoder = CLIPStack(config, self.mesh_device, self.ccl_manager, self.parallel_config)
-        self.text_projection = None

-    def load_torch_state_dict(self, state_dict):
-        self.embeddings.load_torch_state_dict(substate(state_dict, "text_model.embeddings"))
-        self.encoder.load_torch_state_dict(substate(state_dict, "text_model.encoder"))
-
-        self.final_layer_norm = bf16_tensor(
-            state_dict["text_model.final_layer_norm.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.final_layer_norm_bias = bf16_tensor(
-            state_dict["text_model.final_layer_norm.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
+        self.final_layer_norm = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.final_layer_norm_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.text_projection = (
+            Parameter(total_shape=[config.embed_dim, config.projection_dim], device=mesh_device)
+            if config.projection_dim is not None
+            else None
         )
-        if "text_projection.weight" in state_dict:
-            self.text_projection = bf16_tensor(
-                state_dict["text_projection.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-            )
-        else:
-            self.text_projection = None
+
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        rename_substate(state, "text_model.embeddings", "embeddings")
+        rename_substate(state, "text_model.encoder", "encoder")
+
+        if "text_model.final_layer_norm.weight" in state:
+            state["final_layer_norm"] = state.pop("text_model.final_layer_norm.weight")
+        if "text_model.final_layer_norm.bias" in state:
+            state["final_layer_norm_bias"] = state.pop("text_model.final_layer_norm.bias")
+        if "text_projection.weight" in state:
+            state["text_projection"] = state.pop("text_projection.weight")

     def forward(
         self,
         prompt_tokenized: ttnn.Tensor,
         mesh_device: ttnn.Device,
         *,
-        with_projection: bool | None = None,
         return_normalized_state: bool = False,
     ) -> tuple[ttnn.Tensor, ...]:
-        if with_projection is None:
-            with_projection = self.text_projection is not None
-
         hidden_states = self.embeddings(prompt_tokenized, mesh_device)

         causal_attention_mask = create_4d_causal_attention_mask(
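The new hook rewrites an HF-style state dict in place so the remaining flat keys match the Parameter attribute names ("final_layer_norm" rather than "text_model.final_layer_norm.weight"). A standalone demonstration of those flat renames with dummy tensors (shapes are illustrative only):

import torch

state = {
    "text_model.final_layer_norm.weight": torch.zeros(768),
    "text_model.final_layer_norm.bias": torch.zeros(768),
    "text_projection.weight": torch.zeros(768, 768),
}

# The same pops performed by _prepare_torch_state above:
state["final_layer_norm"] = state.pop("text_model.final_layer_norm.weight")
state["final_layer_norm_bias"] = state.pop("text_model.final_layer_norm.bias")
state["text_projection"] = state.pop("text_projection.weight")

print(sorted(state))  # ['final_layer_norm', 'final_layer_norm_bias', 'text_projection']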
@@ -142,8 +140,8 @@ def forward(
         final_hidden_layer = encoder_output[-1]  # final hidden layer
         normalized_final_state = ttnn.layer_norm(  # final layer norm
             final_hidden_layer,
-            weight=self.final_layer_norm,
-            bias=self.final_layer_norm_bias,
+            weight=self.final_layer_norm.data,
+            bias=self.final_layer_norm_bias.data,
             epsilon=self.config.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -160,11 +158,8 @@ def forward(
             ccl_manager=self.ccl_manager,
         )

-        # apply text projection if specified
-        if with_projection:
-            if self.text_projection is None:
-                raise ValueError("projection weights are not loaded")
-            text_projection_transposed = ttnn.transpose(self.text_projection, -2, -1)
+        if self.text_projection is not None:
+            text_projection_transposed = ttnn.transpose(self.text_projection.data, -2, -1)
             pooled_output = ttnn.matmul(
                 pooled_output, text_projection_transposed, compute_kernel_config=self.compute_kernel_config
             )
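With the with_projection flag gone, the pooled output is projected exactly when the config declared a projection_dim. The ttnn transpose plus matmul pair computes the usual CLIP text projection; a torch equivalent of the math, with illustrative shapes:

import torch

batch, dim = 2, 768                       # CLIP-L text tower: embed_dim == projection_dim == 768
pooled = torch.randn(batch, dim)
text_projection = torch.randn(dim, dim)   # square here, so the transpose orientation is hidden

projected = pooled @ text_projection.transpose(-2, -1)  # mirrors ttnn.transpose + ttnn.matmul
print(projected.shape)  # torch.Size([2, 768])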
@@ -253,14 +248,9 @@ def __init__(
             fp32_dest_acc_en=True,
             packer_l1_acc=True,
         )
-        self.layers = [
+        self.layers = ModuleList(
             CLIPEncoderLayer(config, mesh_device, ccl_manager, parallel_config) for _ in range(config.num_hidden_layers)
-        ]
-
-    def load_torch_state_dict(self, state_dict):
-        layer_states = indexed_substates(state_dict, "layers")
-        for layer, layer_state in zip(self.layers, layer_states):
-            layer.load_torch_state_dict(layer_state)
+        )

     def forward(
         self,
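Registering the layers in a ModuleList rather than a plain Python list is presumably what makes the hand-written indexed_substates loop unnecessary: the base class can now discover the per-layer submodules itself and route "layers.<i>.<key>" checkpoint entries to them. A small sketch of that routing, by analogy with torch.nn.ModuleList (an assumption, since the tt_dit ModuleList is not shown here):

state = {
    "layers.0.layer_norm1.weight": "w0",
    "layers.1.layer_norm1.weight": "w1",
}

per_layer = {}
for key, value in state.items():
    _, index, rest = key.split(".", 2)        # "layers", "<i>", remaining key
    per_layer.setdefault(int(index), {})[rest] = value

print(per_layer)
# {0: {'layer_norm1.weight': 'w0'}, 1: {'layer_norm1.weight': 'w1'}}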
@@ -298,8 +288,6 @@ def __init__(
             fp32_dest_acc_en=True,
             packer_l1_acc=True,
         )
-        self.layer_norm1 = None
-        self.layer_norm2 = None
         self.layer_norm_eps = config.layer_norm_eps
         self.self_attn = CLIPAttention(config, mesh_device, ccl_manager, parallel_config)
         self.parallel_config = parallel_config
@@ -321,31 +309,23 @@ def __init__(
         )
         self.ccl_manager = ccl_manager

-    def load_torch_state_dict(self, state_dict):
-        self.layer_norm1 = bf16_tensor(
-            state_dict["layer_norm1.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm1_bias = bf16_tensor(
-            state_dict["layer_norm1.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm2 = bf16_tensor(
-            state_dict["layer_norm2.weight"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
-        self.layer_norm2_bias = bf16_tensor(
-            state_dict["layer_norm2.bias"], device=self.mesh_device, layout=ttnn.TILE_LAYOUT
-        )
+        self.layer_norm1 = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm1_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm2 = Parameter(total_shape=[config.embed_dim], device=mesh_device)
+        self.layer_norm2_bias = Parameter(total_shape=[config.embed_dim], device=mesh_device)

-        self.self_attn.load_torch_state_dict(substate(state_dict, "self_attn"))
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        if "layer_norm1.weight" in state:
+            state["layer_norm1"] = state.pop("layer_norm1.weight")
+        if "layer_norm1.bias" in state:
+            state["layer_norm1_bias"] = state.pop("layer_norm1.bias")
+        if "layer_norm2.weight" in state:
+            state["layer_norm2"] = state.pop("layer_norm2.weight")
+        if "layer_norm2.bias" in state:
+            state["layer_norm2_bias"] = state.pop("layer_norm2.bias")

-        # remap MLP keys from fc1/fc2 to ff1/ff2 format
-        mlp_state = substate(state_dict, "mlp")
-        remapped_mlp_state = {
-            "ff1.weight": mlp_state["fc1.weight"],
-            "ff1.bias": mlp_state["fc1.bias"],
-            "ff2.weight": mlp_state["fc2.weight"],
-            "ff2.bias": mlp_state["fc2.bias"],
-        }
-        self.mlp.load_torch_state_dict(remapped_mlp_state)
+        rename_substate(state, "mlp.fc1", "mlp.ff1")
+        rename_substate(state, "mlp.fc2", "mlp.ff2")

     def forward(
         self,
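The layer norm weights used to be raw ttnn tensors created inside load_torch_state_dict; now they are declared up front with their full shape and read through .data in forward. The real Parameter class is not in this excerpt; the toy below only illustrates the declare-shape-now, fill-data-later contract it appears to provide (the constructor arguments and .data attribute are taken from the diff, the load method is an assumption):

import torch


class ParameterSketch:
    """Toy stand-in: the shape is fixed at construction, the tensor arrives at load time."""

    def __init__(self, total_shape, device=None, layout=None):
        self.total_shape = list(total_shape)
        self.device, self.layout = device, layout
        self.data = None  # populated when the checkpoint is loaded

    def load(self, tensor):
        assert list(tensor.shape) == self.total_shape, "checkpoint shape mismatch"
        self.data = tensor  # the real class would also move it to the mesh device/layout


layer_norm1 = ParameterSketch(total_shape=[768])
layer_norm1.load(torch.ones(768))
print(layer_norm1.data.shape)  # torch.Size([768])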
@@ -357,8 +337,8 @@ def forward(
         residual = hidden_states
         hidden_states = ttnn.layer_norm(
             hidden_states,
-            weight=self.layer_norm1,
-            bias=self.layer_norm1_bias,
+            weight=self.layer_norm1.data,
+            bias=self.layer_norm1_bias.data,
             epsilon=self.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -368,8 +348,8 @@ def forward(
         residual = hidden_states
         hidden_states = ttnn.layer_norm(
             hidden_states,
-            weight=self.layer_norm2,
-            bias=self.layer_norm2_bias,
+            weight=self.layer_norm2.data,
+            bias=self.layer_norm2_bias.data,
             epsilon=self.layer_norm_eps,
             compute_kernel_config=self.compute_kernel_config,
         )
@@ -467,11 +447,8 @@ def __init__(
         self.v_proj = Linear(in_features=self.embed_dim, out_features=self.embed_dim, mesh_device=self.mesh_device)
         self.o_proj = Linear(in_features=self.embed_dim, out_features=self.embed_dim, mesh_device=self.mesh_device)

-    def load_torch_state_dict(self, state_dict):
-        self.q_proj.load_torch_state_dict(substate(state_dict, "q_proj"))
-        self.k_proj.load_torch_state_dict(substate(state_dict, "k_proj"))
-        self.v_proj.load_torch_state_dict(substate(state_dict, "v_proj"))
-        self.o_proj.load_torch_state_dict(substate(state_dict, "out_proj"))
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        rename_substate(state, "out_proj", "o_proj")

     def forward(self, hidden_states, causal_attention_mask):
         batch_size, seq_length, _ = hidden_states.shape
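Only the output projection needs remapping: the checkpoint calls it out_proj while the attribute here is o_proj, and the q/k/v keys already match their attributes. The same rename, spelled out on a toy state dict (mirroring the assumed in-place semantics of rename_substate):

state = {"q_proj.weight": "q", "out_proj.weight": "o", "out_proj.bias": "ob"}

for key in list(state):
    if key == "out_proj" or key.startswith("out_proj."):
        state["o_proj" + key[len("out_proj"):]] = state.pop(key)

print(sorted(state))  # ['o_proj.bias', 'o_proj.weight', 'q_proj.weight']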
@@ -570,32 +547,22 @@ def __init__(self, config, mesh_device: ttnn.Device) -> None:
         self.config = config
         self.mesh_device = mesh_device

-        self.token_embedding = None
-        self.position_embedding = None
-
-    def load_torch_state_dict(self, state_dict):
-        self.token_embedding = bf16_tensor(
-            state_dict["token_embedding.weight"], device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT
+        self.token_embedding = Parameter(
+            total_shape=[config.vocab_size, config.embed_dim],
+            device=mesh_device,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
         )
-        self.position_embedding = bf16_tensor(
-            state_dict["position_embedding.weight"], device=self.mesh_device, layout=ttnn.ROW_MAJOR_LAYOUT
+        self.position_embedding = Parameter(
+            total_shape=[config.max_prompt_length, config.embed_dim],
+            device=mesh_device,
+            layout=ttnn.ROW_MAJOR_LAYOUT,
         )

-    # TODO: Move to parameters to reuse module functionality
-    def to_cached_state_dict(self, path_prefix, path_suffix=".tensorbin"):
-        cache_dict = {}
-        token_embedding_weights_path = path_prefix + "token_embedding_weights" + path_suffix
-        position_embedding_weights_path = path_prefix + "position_embedding_weights" + path_suffix
-        ttnn.dump_tensor(token_embedding_weights_path, self.token_embedding)
-        ttnn.dump_tensor(position_embedding_weights_path, self.position_embedding)
-        cache_dict["token_embedding"] = token_embedding_weights_path
-        cache_dict["position_embedding"] = position_embedding_weights_path
-
-        return cache_dict
-
-    def from_cached_state_dict(self, cache_dict):
-        self.token_embedding = ttnn.load_tensor(cache_dict["token_embedding"], device=self.mesh_device)
-        self.position_embedding = ttnn.load_tensor(cache_dict["position_embedding"], device=self.mesh_device)
+    def _prepare_torch_state(self, state: dict[str, torch.Tensor]) -> None:
+        if "token_embedding.weight" in state:
+            state["token_embedding"] = state.pop("token_embedding.weight")
+        if "position_embedding.weight" in state:
+            state["position_embedding"] = state.pop("position_embedding.weight")

     def forward(self, prompt: ttnn.Tensor, device: ttnn.Device) -> ttnn.Tensor:
         seq_len = prompt.shape[-1]
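The embedding tables are now shape-declared Parameters as well, and the hand-rolled to_cached_state_dict / from_cached_state_dict pair disappears, presumably because Parameter-backed weights go through the Module base class's own caching path (an inference from the deleted TODO comment, not something shown in this commit). A quick check that the declared total_shapes line up with an HF text encoder (checkpoint name illustrative; requires transformers):

from transformers import CLIPTextModel

sd = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").state_dict()

# [vocab_size, embed_dim] and [max_prompt_length, embed_dim] for CLIP ViT-L/14:
print(tuple(sd["text_model.embeddings.token_embedding.weight"].shape))     # (49408, 768)
print(tuple(sd["text_model.embeddings.position_embedding.weight"].shape))  # (77, 768)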
@@ -604,11 +571,11 @@ def forward(self, prompt: ttnn.Tensor, device: ttnn.Device) -> ttnn.Tensor:
             prompt = prompt[:, : self.config.max_prompt_length]
             seq_len = self.config.max_prompt_length

-        input_embeddings = ttnn.embedding(prompt, self.token_embedding, layout=ttnn.TILE_LAYOUT)
+        input_embeddings = ttnn.embedding(prompt, self.token_embedding.data, layout=ttnn.TILE_LAYOUT)

         position_ids = torch.arange(seq_len).expand((1, -1))  # shape: (1, seq_len)
         position_ids_ttnn = ttnn.from_torch(position_ids, dtype=ttnn.uint32, layout=ttnn.TILE_LAYOUT, device=device)
-        position_embeddings = ttnn.embedding(position_ids_ttnn, self.position_embedding, layout=ttnn.TILE_LAYOUT)
+        position_embeddings = ttnn.embedding(position_ids_ttnn, self.position_embedding.data, layout=ttnn.TILE_LAYOUT)

         return input_embeddings + position_embeddings
