 from typing import TYPE_CHECKING, Annotated, Callable, Dict, Optional, Union

 import torch
-import torch.nn.functional as F
 from megatron.core import tensor_parallel
 from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
 from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
 from megatron.core.transformer.transformer_config import TransformerConfig
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
 from megatron.core.utils import make_viewless_tensor
-from torch import Tensor, nn
-
-from cosmos1.models.autoregressive.nemo.cosmos import CosmosConfig, CosmosConfig4B, CosmosModel, RotaryEmbedding3D
+from torch import nn
+
+from cosmos1.models.autoregressive.nemo.cosmos import (
+    CosmosConfig,
+    CosmosConfig4B,
+    CosmosConfig12B,
+    CosmosModel,
+    RotaryEmbedding3D,
+)
 from cosmos1.models.autoregressive.nemo.inference.inference_controller import CosmosInferenceWrapper
 from cosmos1.utils import log

 if TYPE_CHECKING:
     from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

-from megatron.core import InferenceParams
 from megatron.core.packed_seq_params import PackedSeqParams
 from megatron.core.transformer.transformer_block import TransformerBlock
 from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank
-from nemo.collections.llm.gpt.model.llama import Llama3Config
 from nemo.collections.llm.utils import Config
 from nemo.lightning import OptimizerModule, io
 from nemo.lightning.base import teardown


 class CosmosTransformerBlock(TransformerBlock):
     def forward(
         self,
-        hidden_states: Tensor,
-        attention_mask: Tensor,
-        context: Tensor = None,
-        context_mask: Tensor = None,
-        rotary_pos_emb: Tensor = None,
-        rotary_pos_cos: Tensor = None,
-        rotary_pos_sin: Tensor = None,
-        attention_bias: Tensor = None,
-        inference_params: InferenceParams = None,
+        *args,
         packed_seq_params: PackedSeqParams = None,
         extra_positional_embeddings=None,
+        **kwargs,
     ):
         packed_seq_params = {"abs_pos_embed": extra_positional_embeddings}
         return super().forward(
-            hidden_states,
-            attention_mask,
-            context,
-            context_mask,
-            rotary_pos_emb,
-            rotary_pos_cos,
-            rotary_pos_sin,
-            attention_bias,
-            inference_params,
-            packed_seq_params,
+            *args,
+            packed_seq_params=packed_seq_params,
+            **kwargs,
         )


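The rewritten forward above no longer enumerates every TransformerBlock argument; it only repurposes packed_seq_params to carry the absolute positional embeddings and forwards everything else untouched. A minimal sketch of that pass-through pattern (Base/Child are illustrative stand-ins, not classes from this commit):

# Illustrative sketch of the *args/**kwargs pass-through used above.
class Base:
    def forward(self, hidden_states, attention_mask=None, packed_seq_params=None, **kwargs):
        return hidden_states, packed_seq_params

class Child(Base):
    def forward(self, *args, packed_seq_params=None, extra_positional_embeddings=None, **kwargs):
        # Repurpose the existing slot to smuggle the absolute positional embeddings through.
        packed_seq_params = {"abs_pos_embed": extra_positional_embeddings}
        return super().forward(*args, packed_seq_params=packed_seq_params, **kwargs)

out, pos = Child().forward("h", attention_mask=None, extra_positional_embeddings="abs")
assert pos == {"abs_pos_embed": "abs"}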
@@ -361,7 +350,7 @@ def cosmos_data_step(dataloader_iter) -> Dict[str, torch.Tensor]:
     if "cu_seqlens" in _batch:
         raise ValueError("Packed sequence cu_seqlens not supported")

-    required_device_keys.update(("context", "abs_pos_embed"))
+    required_device_keys.update(("context", "abs_pos_embed", "action"))
     if parallel_state.is_pipeline_first_stage():
         required_device_keys.update(("tokens", "position_ids"))
     if parallel_state.is_pipeline_last_stage():
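For context, required_device_keys in cosmos_data_step controls which batch entries (now including "action") are moved onto the device for the current pipeline stage. A minimal sketch of that filtering pattern; the helper below is illustrative, not the NeMo implementation:

# Illustrative only: move just the keys a pipeline stage needs to the device.
def move_required_keys(batch: dict, required_device_keys: set, device: str = "cuda") -> dict:
    return {
        key: (val.to(device, non_blocking=True) if key in required_device_keys and torch.is_tensor(val) else val)
        for key, val in batch.items()
    }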
@@ -398,38 +387,27 @@ def cosmos_forward_step(model, batch) -> torch.Tensor:


 @dataclass
-class CosmosConfigVideo2World5B(Llama3Config):
-    qk_layernorm: bool = True
-    rope_dim: str = "3D"
+class CosmosVideo2WorldConfig:
     vocab_size: int = 64064
     output_layer_vocab_size: int = 64000
-    activation_func = F.silu
-    rotary_base: int = 500_000
     seq_length: int = 12864
-    num_layers: int = 16
-    hidden_size: int = 4096
-    ffn_hidden_size: int = 14336
-    num_attention_heads: int = 32
-    num_query_groups: int = 8
-    layernorm_epsilon: float = 1e-5
-    use_cpu_initialization: bool = True
-    make_vocab_size_divisible_by: int = 64
-    kv_channels: int = 128
-    crossattn_emb_size: int = 1024
     latent_shape = [5, 40, 64]
     pad_to_multiple_of = 64
     forward_step_fn: Callable = cosmos_forward_step
     transformer_layer_spec = get_cosmos_video2world_spec()
     data_step_fn: Callable = cosmos_data_step
     attention_backend: AttnBackend = AttnBackend.flash
+    crossattn_emb_size: int = 1024
+    kv_channels: int = 128
+    training_type: str | None = "text_to_video"

     def configure_model(self, tokenizer) -> "MCoreGPTModel":
         self.transformer_layer_spec = get_cosmos_video2world_spec()
         model = super().configure_model(tokenizer)
         if self.rope_dim == "3D":
             model.rotary_pos_emb = RotaryEmbedding3D(
                 seq_len=self.seq_length,
-                training_type="text_to_video",
+                training_type=self.training_type,
                 pad_to_multiple_of=self.pad_to_multiple_of,
                 kv_channels=self.kv_channels,
                 max_position_embeddings=self.seq_length,
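Because training_type is now a config field rather than a hard-coded string, a variant can switch the rotary-embedding mode declaratively instead of overriding configure_model. A hedged sketch; the subclass name and the alternative value are hypothetical and not part of this commit:

# Hypothetical variant config: only the rotary-embedding training type changes.
@dataclass
class CosmosConfigVideo2WorldCustom(CosmosConfigVideo2World5B):
    training_type: str | None = "video_to_video"  # illustrative value, picked up by RotaryEmbedding3D above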
@@ -467,78 +445,13 @@ def configure_model(self, tokenizer) -> "MCoreGPTModel":


 @dataclass
-class CosmosConfigVideo2World13B(Llama3Config):
-    qk_layernorm: bool = True
-    rope_dim: str = "3D"
-    vocab_size: int = 64064
-    output_layer_vocab_size: int = 64000
-    activation_func = F.silu
-    rotary_base: int = 500_000
-    seq_length: int = 12864
-    num_layers: int = 40
-    hidden_size: int = 5120
-    ffn_hidden_size: int = 14336
-    num_attention_heads: int = 32
-    num_query_groups: int = 8
-    layernorm_epsilon: float = 1e-5
-    use_cpu_initialization: bool = True
-    make_vocab_size_divisible_by: int = 128
-    kv_channels: int = 128
-    crossattn_emb_size: int = 1024
-    original_latent_shape = [3, 40, 64]
-    apply_yarn: bool = True
-    yarn_beta_fast: int = 4
-    yarn_beta_slow: int = 1
-    yarn_scale: int = 2
-    original_seq_len = 8192
-    latent_shape = [5, 40, 64]
-    pad_to_multiple_of = 64
-    forward_step_fn: Callable = cosmos_forward_step
-    transformer_layer_spec = get_cosmos_video2world_spec()
-    data_step_fn: Callable = cosmos_data_step
-    attention_backend: AttnBackend = AttnBackend.flash
+class CosmosConfigVideo2World5B(CosmosVideo2WorldConfig, CosmosConfig4B):
+    make_vocab_size_divisible_by: int = 64

-    def configure_model(self, tokenizer) -> "MCoreGPTModel":
-        self.transformer_layer_spec = get_cosmos_video2world_spec()
-        model = super().configure_model(tokenizer)
-        if self.rope_dim == "3D":
-            model.rotary_pos_emb = RotaryEmbedding3D(
-                seq_len=self.seq_length,
-                training_type="text_to_video",
-                pad_to_multiple_of=self.pad_to_multiple_of,
-                kv_channels=self.kv_channels,
-                max_position_embeddings=self.seq_length,
-                original_max_position_embeddings=self.original_seq_len if hasattr(self, "original_seq_len") else None,
-                rotary_base=self.rotary_base,
-                apply_yarn=True if hasattr(self, "apply_yarn") else False,
-                scale=self.yarn_scale if hasattr(self, "yarn_scale") else None,
-                extrapolation_factor=1,
-                attn_factor=1,
-                beta_fast=self.yarn_beta_fast if hasattr(self, "yarn_beta_fast") else 32,
-                beta_slow=self.yarn_beta_slow if hasattr(self, "yarn_beta_slow") else 1,
-                latent_shape=self.latent_shape,
-                original_latent_shape=self.original_latent_shape if hasattr(self, "original_latent_shape") else None,
-            )
-        model.output_layer = tensor_parallel.ColumnParallelLinear(
-            self.hidden_size,
-            self.output_layer_vocab_size,
-            config=self,
-            init_method=self.init_method,
-            bias=False,
-            skip_bias_add=False,
-            gather_output=False,
-            skip_weight_param_allocation=False,
-            embedding_activation_buffer=None,
-            grad_output_buffer=None,
-        )

-        model.decoder = CosmosTransformerBlock(
-            config=self,
-            spec=self.transformer_layer_spec,
-            pre_process=model.pre_process,
-            post_process=model.post_process,
-        )
-        return model
+@dataclass
+class CosmosConfigVideo2World13B(CosmosVideo2WorldConfig, CosmosConfig12B):
+    make_vocab_size_divisible_by: int = 128


 class CosmosVideo2WorldModel(CosmosModel):
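The two model-size configs above are now thin combinations of the shared video2world mixin and the existing Cosmos base configs, with the mixin's defaults taking precedence in Python's MRO. A minimal, hedged illustration of that composition (the class names and field values below are stand-ins, not the real Cosmos defaults):

# Illustrative MRO sketch of the mixin-plus-base composition used above.
from dataclasses import dataclass

@dataclass
class Base4B:                     # stand-in for CosmosConfig4B
    seq_length: int = 8192
    kv_channels: int = 128

@dataclass
class VideoMixin:                 # stand-in for CosmosVideo2WorldConfig
    seq_length: int = 12864       # overrides the base default
    training_type: str | None = "text_to_video"

@dataclass
class Config5B(VideoMixin, Base4B):   # stand-in for CosmosConfigVideo2World5B
    make_vocab_size_divisible_by: int = 64

cfg = Config5B()
assert cfg.seq_length == 12864        # the mixin listed first wins in the MRO
assert cfg.training_type == "text_to_video"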
@@ -549,7 +462,9 @@ def __init__(
         tokenizer: Optional["TokenizerSpec"] = None,
         model_transform: Optional[Callable[[nn.Module], nn.Module]] = None,
     ):
-        super().__init__(config or CosmosConfig4B(), optim=optim, tokenizer=tokenizer, model_transform=model_transform)
+        super().__init__(
+            config or CosmosConfigVideo2World5B(), optim=optim, tokenizer=tokenizer, model_transform=model_transform
+        )
         self.config = config

     def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold) -> torch.Tensor:
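With the default swapped above, constructing the model without an explicit config now falls back to the 5B video2world configuration rather than CosmosConfig4B. A hedged usage sketch; optimizer, tokenizer, and trainer wiring are elided:

# Illustrative usage only; argument names follow the diff.
model_5b = CosmosVideo2WorldModel()                                   # defaults to CosmosConfigVideo2World5B()
model_13b = CosmosVideo2WorldModel(config=CosmosConfigVideo2World13B())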