diff --git a/cortex/config/hydra/branches/transformer_decoder.yaml b/cortex/config/hydra/branches/transformer_decoder.yaml
new file mode 100644
index 0000000..38d6d4c
--- /dev/null
+++ b/cortex/config/hydra/branches/transformer_decoder.yaml
@@ -0,0 +1,13 @@
+# Transformer Decoder Branch configuration
+
+_target_: cortex.model.branch.TransformerDecoderBranch
+in_dim: ??? # Must be provided, should match trunk output dimension
+out_dim: ??? # Must be provided, output dimension of the branch
+num_layers: 2 # Number of transformer decoder layers
+nhead: 8 # Number of attention heads
+dim_feedforward: null # Optional, if null will be set to 4 * in_dim
+dropout: 0.1 # Dropout probability
+activation: "relu" # Activation function for the transformer
+layer_norm_eps: 1.0e-5 # Epsilon value for layer normalization
+batch_first: true # Input tensors have batch dimension first (batch, seq, features)
+pooling_type: "mean" # Pooling strategy for sequence features ("mean" or "weighted_mean")
diff --git a/cortex/config/hydra/branches/transformer_encoder.yaml b/cortex/config/hydra/branches/transformer_encoder.yaml
new file mode 100644
index 0000000..38ba726
--- /dev/null
+++ b/cortex/config/hydra/branches/transformer_encoder.yaml
@@ -0,0 +1,13 @@
+# Transformer Encoder Branch configuration
+
+_target_: cortex.model.branch.TransformerEncoderBranch
+in_dim: ??? # Must be provided, should match trunk output dimension
+out_dim: ??? # Must be provided, output dimension of the branch
+num_layers: 2 # Number of transformer encoder layers
+nhead: 8 # Number of attention heads
+dim_feedforward: null # Optional, if null will be set to 4 * in_dim
+dropout: 0.1 # Dropout probability
+activation: "relu" # Activation function for the transformer
+layer_norm_eps: 1.0e-5 # Epsilon value for layer normalization
+batch_first: true # Input tensors have batch dimension first (batch, seq, features)
+pooling_type: "mean" # Pooling strategy for sequence features ("mean" or "weighted_mean")
diff --git a/cortex/config/hydra/roots/transformer_decoder.yaml b/cortex/config/hydra/roots/transformer_decoder.yaml
new file mode 100644
index 0000000..3c1ffd8
--- /dev/null
+++ b/cortex/config/hydra/roots/transformer_decoder.yaml
@@ -0,0 +1,12 @@
+# Transformer Decoder Root configuration
+
+_target_: cortex.model.root.TransformerDecoderRoot
+tokenizer_transform: ??? # Must be provided, instance of HuggingFaceTokenizerTransform
+model_name_or_path: ??? # Must be provided, Hugging Face model identifier or path
+max_len: 512 # Maximum sequence length for padding/truncation
+use_pretrained: true # Whether to use pre-trained weights from HF
+attn_implementation: "sdpa" # Attention implementation ("sdpa", "flash_attention_2", "eager")
+config_overrides: null # Optional overrides if use_pretrained=false
+corruption_process: null # Optional corruption process for masked language modeling
+train_transforms: null # Optional transforms applied only during training
+eval_transforms: null # Optional transforms applied only during evaluation
diff --git a/cortex/config/hydra/roots/transformer_encoder.yaml b/cortex/config/hydra/roots/transformer_encoder.yaml
new file mode 100644
index 0000000..72d1f7c
--- /dev/null
+++ b/cortex/config/hydra/roots/transformer_encoder.yaml
@@ -0,0 +1,12 @@
+# Transformer Encoder Root configuration
+
+_target_: cortex.model.root.TransformerEncoderRoot
+tokenizer_transform: ??? # Must be provided, instance of HuggingFaceTokenizerTransform
+model_name_or_path: ???
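The branch configs above leave only the `???` fields unset; everything else has a working default (note that `dim_feedforward: null` falls back to `4 * in_dim`). A minimal sketch of loading one of them and instantiating it with Hydra; the `in_dim`/`out_dim` values are illustrative and must match the actual trunk output dimension:

```python
# Sketch: fill in the required fields of the new branch config and instantiate it.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("cortex/config/hydra/branches/transformer_encoder.yaml")
cfg.in_dim = 512    # must match the trunk output dimension
cfg.out_dim = 512
branch = instantiate(cfg)  # -> cortex.model.branch.TransformerEncoderBranch
# dim_feedforward is null in the config, so the module defaults it to 4 * in_dim = 2048
```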
# Must be provided, Hugging Face model identifier or path +max_len: 512 # Maximum sequence length for padding/truncation +use_pretrained: true # Whether to use pre-trained weights from HF +attn_implementation: "sdpa" # Attention implementation ("sdpa", "flash_attention_2", "eager") +config_overrides: null # Optional overrides if use_pretrained=false +corruption_process: null # Optional corruption process for masked language modeling +train_transforms: null # Optional transforms applied only during training +eval_transforms: null # Optional transforms applied only during evaluation diff --git a/cortex/model/branch/__init__.py b/cortex/model/branch/__init__.py index 16ed371..97346ab 100644 --- a/cortex/model/branch/__init__.py +++ b/cortex/model/branch/__init__.py @@ -1,9 +1,15 @@ from ._abstract_branch import BranchNode, BranchNodeOutput from ._conv1d_branch import Conv1dBranch, Conv1dBranchOutput +from ._transformer_decoder_branch import TransformerDecoderBranch, TransformerDecoderBranchOutput +from ._transformer_encoder_branch import TransformerEncoderBranch, TransformerEncoderBranchOutput __all__ = [ "BranchNode", "BranchNodeOutput", "Conv1dBranch", "Conv1dBranchOutput", + "TransformerEncoderBranch", + "TransformerEncoderBranchOutput", + "TransformerDecoderBranch", + "TransformerDecoderBranchOutput", ] diff --git a/cortex/model/branch/_transformer_decoder_branch.py b/cortex/model/branch/_transformer_decoder_branch.py new file mode 100644 index 0000000..d2f6616 --- /dev/null +++ b/cortex/model/branch/_transformer_decoder_branch.py @@ -0,0 +1,142 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from cortex.model.branch import BranchNode, BranchNodeOutput +from cortex.model.elemental import MeanPooling, WeightedMeanPooling +from cortex.model.trunk import PaddedTrunkOutput + + +@dataclass +class TransformerDecoderBranchOutput(BranchNodeOutput): + """Output of TransformerDecoderBranch.""" + + branch_features: torch.Tensor + branch_mask: torch.Tensor + pooled_features: torch.Tensor + + +class TransformerDecoderBranch(BranchNode): + """ + Branch node that applies additional Transformer decoder layers with causal self-attention + to features from the trunk. 
+ + Example Hydra Config: + ```yaml + branches: + transformer_decoder_branch: + _target_: cortex.model.branch.TransformerDecoderBranch + in_dim: 512 # Should match trunk output + out_dim: 512 + num_layers: 2 + nhead: 8 + dim_feedforward: 2048 # Optional, defaults to 4 * in_dim + dropout: 0.1 + activation: "relu" + layer_norm_eps: 1e-5 + batch_first: True + pooling_type: "mean" + ``` + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_layers: int, + nhead: int, + dim_feedforward: Optional[int] = None, + dropout: float = 0.1, + activation: str = "relu", + layer_norm_eps: float = 1e-5, + batch_first: bool = True, + pooling_type: str = "mean", + **kwargs, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Set default dim_feedforward if not provided + if dim_feedforward is None: + dim_feedforward = 4 * in_dim + + # Create decoder layer and stack them + decoder_layer = nn.TransformerDecoderLayer( + d_model=in_dim, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + batch_first=batch_first, + ) + + self.transformer_layers = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_layers, + ) + + # Add projection layer if dimensions don't match + if in_dim != out_dim: + self.projection = nn.Linear(in_dim, out_dim) + else: + self.projection = None + + # Set up pooling operation + if pooling_type == "mean": + self.pooling_op = MeanPooling() + elif pooling_type == "weighted_mean": + self.pooling_op = WeightedMeanPooling(out_dim) + else: + raise ValueError(f"Unsupported pooling_type: {pooling_type}") + + def forward( + self, + trunk_outputs: PaddedTrunkOutput, + ) -> TransformerDecoderBranchOutput: + """ + Args: + trunk_outputs: PaddedTrunkOutput containing trunk_features and padding_mask + + Returns: + TransformerDecoderBranchOutput containing: + branch_features: Sequence features after transformer layers + branch_mask: Padding mask for the output sequence + pooled_features: Pooled sequence features + """ + features = trunk_outputs.trunk_features + padding_mask = trunk_outputs.padding_mask + + # Convert padding_mask to tgt_key_padding_mask for transformer + # PyTorch transformer expects True for positions to be *masked* + tgt_key_padding_mask = ~padding_mask.bool() + + # Create causal mask to ensure autoregressive attention + seq_len = features.size(1) + causal_mask = nn.Transformer.generate_square_subsequent_mask(sz=seq_len, device=features.device) + + # Apply transformer layers + # For self-attention only, we pass features as both tgt and memory + branch_features = self.transformer_layers( + tgt=features, + memory=features, # Use features as memory for self-attention only + tgt_mask=causal_mask, # Apply causal masking + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=tgt_key_padding_mask, # Same as tgt padding mask + ) + + # Apply projection if needed + if self.projection is not None: + branch_features = self.projection(branch_features) + + # Pool features + pooled_features = self.pooling_op(branch_features, padding_mask) + + return TransformerDecoderBranchOutput( + branch_features=branch_features.contiguous(), + branch_mask=padding_mask, + pooled_features=pooled_features, + ) diff --git a/cortex/model/branch/_transformer_encoder_branch.py b/cortex/model/branch/_transformer_encoder_branch.py new file mode 100644 index 0000000..8c18f93 --- /dev/null +++ b/cortex/model/branch/_transformer_encoder_branch.py @@ -0,0 +1,130 @@ +from dataclasses 
import dataclass +from typing import Optional + +import torch +from torch import nn + +from cortex.model.branch import BranchNode, BranchNodeOutput +from cortex.model.elemental import MeanPooling, WeightedMeanPooling +from cortex.model.trunk import PaddedTrunkOutput + + +@dataclass +class TransformerEncoderBranchOutput(BranchNodeOutput): + """Output of TransformerEncoderBranch.""" + + branch_features: torch.Tensor + branch_mask: torch.Tensor + pooled_features: torch.Tensor + + +class TransformerEncoderBranch(BranchNode): + """ + Branch node that applies additional Transformer encoder layers to features from the trunk. + + Example Hydra Config: + ```yaml + branches: + transformer_encoder_branch: + _target_: cortex.model.branch.TransformerEncoderBranch + in_dim: 512 # Should match trunk output + out_dim: 512 + num_layers: 2 + nhead: 8 + dim_feedforward: 2048 # Optional, defaults to 4 * in_dim + dropout: 0.1 + activation: "relu" + layer_norm_eps: 1e-5 + batch_first: True + pooling_type: "mean" + ``` + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + num_layers: int, + nhead: int, + dim_feedforward: Optional[int] = None, + dropout: float = 0.1, + activation: str = "relu", + layer_norm_eps: float = 1e-5, + batch_first: bool = True, + pooling_type: str = "mean", + **kwargs, + ): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # Set default dim_feedforward if not provided + if dim_feedforward is None: + dim_feedforward = 4 * in_dim + + # Create encoder layer and stack them + encoder_layer = nn.TransformerEncoderLayer( + d_model=in_dim, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + layer_norm_eps=layer_norm_eps, + batch_first=batch_first, + ) + + self.transformer_layers = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_layers, + ) + + # Add projection layer if dimensions don't match + if in_dim != out_dim: + self.projection = nn.Linear(in_dim, out_dim) + else: + self.projection = None + + # Set up pooling operation + if pooling_type == "mean": + self.pooling_op = MeanPooling() + elif pooling_type == "weighted_mean": + self.pooling_op = WeightedMeanPooling(out_dim) + else: + raise ValueError(f"Unsupported pooling_type: {pooling_type}") + + def forward( + self, + trunk_outputs: PaddedTrunkOutput, + ) -> TransformerEncoderBranchOutput: + """ + Args: + trunk_outputs: PaddedTrunkOutput containing trunk_features and padding_mask + + Returns: + TransformerEncoderBranchOutput containing: + branch_features: Sequence features after transformer layers + branch_mask: Padding mask for the output sequence + pooled_features: Pooled sequence features + """ + features = trunk_outputs.trunk_features + padding_mask = trunk_outputs.padding_mask + + # Convert padding_mask to src_key_padding_mask for transformer + # PyTorch transformer expects True for positions to be *masked* + src_key_padding_mask = ~padding_mask.bool() + + # Apply transformer layers + branch_features = self.transformer_layers(src=features, src_key_padding_mask=src_key_padding_mask) + + # Apply projection if needed + if self.projection is not None: + branch_features = self.projection(branch_features) + + # Pool features + pooled_features = self.pooling_op(branch_features, padding_mask) + + return TransformerEncoderBranchOutput( + branch_features=branch_features.contiguous(), + branch_mask=padding_mask, + pooled_features=pooled_features, + ) diff --git a/cortex/model/root/__init__.py b/cortex/model/root/__init__.py index 0a5f1fe..ad3e19b 
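Both branch classes consume a `PaddedTrunkOutput` and differ only in the attention pattern: the encoder applies bidirectional self-attention, while the decoder adds a square subsequent (causal) mask and feeds the same features through both `tgt` and `memory`. A quick smoke-test sketch; it assumes `PaddedTrunkOutput` can be built from just `trunk_features` and `padding_mask`, and the shapes in the comments assume `MeanPooling` reduces over the sequence dimension:

```python
# Sketch: run dummy trunk features through the new branches (shapes are illustrative).
import torch

from cortex.model.branch import TransformerDecoderBranch, TransformerEncoderBranch
from cortex.model.trunk import PaddedTrunkOutput

trunk_out = PaddedTrunkOutput(
    trunk_features=torch.randn(2, 16, 64),  # (batch, seq, features)
    padding_mask=torch.ones(2, 16),         # 1 = real token, 0 = padding
)

encoder_branch = TransformerEncoderBranch(in_dim=64, out_dim=64, num_layers=2, nhead=8)
decoder_branch = TransformerDecoderBranch(in_dim=64, out_dim=32, num_layers=2, nhead=8)

enc_out = encoder_branch(trunk_out)  # bidirectional self-attention over the sequence
dec_out = decoder_branch(trunk_out)  # causal self-attention (square subsequent mask)

print(enc_out.pooled_features.shape)  # torch.Size([2, 64])
print(dec_out.branch_features.shape)  # torch.Size([2, 16, 32]), via the in_dim -> out_dim projection
```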
100644 --- a/cortex/model/root/__init__.py +++ b/cortex/model/root/__init__.py @@ -1,9 +1,15 @@ from ._abstract_root import RootNode, RootNodeOutput from ._conv1d_root import Conv1dRoot, Conv1dRootOutput +from ._transformer_decoder_root import TransformerDecoderRoot, TransformerDecoderRootOutput +from ._transformer_encoder_root import TransformerEncoderRoot, TransformerEncoderRootOutput __all__ = [ "RootNode", "RootNodeOutput", "Conv1dRoot", "Conv1dRootOutput", + "TransformerEncoderRoot", + "TransformerEncoderRootOutput", + "TransformerDecoderRoot", + "TransformerDecoderRootOutput", ] diff --git a/cortex/model/root/_transformer_decoder_root.py b/cortex/model/root/_transformer_decoder_root.py new file mode 100644 index 0000000..bb7f905 --- /dev/null +++ b/cortex/model/root/_transformer_decoder_root.py @@ -0,0 +1,316 @@ +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import numpy as np +import torch +from omegaconf import DictConfig +from torch import LongTensor, nn +from transformers import AutoConfig, AutoModelForCausalLM + +from cortex.corruption import CorruptionProcess, MaskCorruptionProcess +from cortex.model.root import RootNode, RootNodeOutput +from cortex.transforms import HuggingFaceTokenizerTransform, PadTransform, ToTensor + + +@dataclass +class TransformerDecoderRootOutput(RootNodeOutput): + """Output of TransformerDecoderRoot.""" + + root_features: torch.Tensor + padding_mask: torch.Tensor + corrupt_frac: Optional[torch.Tensor] = None + src_tok_idxs: Optional[torch.LongTensor] = None + tgt_tok_idxs: Optional[torch.LongTensor] = None + src_tok_embs: Optional[torch.Tensor] = None + is_corrupted: Optional[torch.Tensor] = None + + +class TransformerDecoderRoot(RootNode): + """ + A root node that wraps a Hugging Face transformer decoder-only model (e.g., GPT-2, LBSTER decoders). + + Example Hydra Config: + ```yaml + roots: + text_decoder: + _target_: cortex.model.root.TransformerDecoderRoot + tokenizer_transform: ??? 
# Needs instantiation elsewhere + model_name_or_path: "gpt2" + use_pretrained: True + max_len: 512 + attn_implementation: "sdpa" + out_dim: 768 # Example, will be inferred + ``` + """ + + def __init__( + self, + tokenizer_transform: HuggingFaceTokenizerTransform, + model_name_or_path: str, + max_len: int, + out_dim: int = None, + use_pretrained: bool = True, + attn_implementation: Optional[str] = "sdpa", + config_overrides: Optional[DictConfig] = None, + corruption_process: Optional[CorruptionProcess] = None, + train_transforms=None, + eval_transforms=None, + **kwargs, + ) -> None: + super().__init__() + self.tokenizer = tokenizer_transform.tokenizer + self.pad_tok_idx = self.tokenizer.padding_idx + self.masking_idx = getattr(self.tokenizer, "masking_idx", None) + self.max_len = max_len + + # Load or create Hugging Face model configuration + if use_pretrained: + self.transformer = AutoModelForCausalLM.from_pretrained( + model_name_or_path, attn_implementation=attn_implementation + ) + else: + config = AutoConfig.from_pretrained(model_name_or_path) + # Apply configuration overrides if specified + if config_overrides is not None: + for key, value in config_overrides.items(): + setattr(config, key, value) + self.transformer = AutoModelForCausalLM.from_config(config) + + # Determine output dimension from model + self._out_dim = self.transformer.config.hidden_size + + # Validate against provided out_dim if specified + if out_dim is not None and out_dim != self._out_dim: + warnings.warn( + f"Provided out_dim ({out_dim}) does not match model's hidden_size ({self._out_dim}). " + f"Using model's hidden_size.", + stacklevel=2, + ) + + # Set up transforms + shared_transforms = [ + tokenizer_transform, + ToTensor(padding_value=self.pad_tok_idx), + PadTransform(max_length=self.max_len, pad_value=self.pad_tok_idx), + ] + train_transforms = [] if train_transforms is None else list(train_transforms.values()) + eval_transforms = [] if eval_transforms is None else list(eval_transforms.values()) + self.train_transform = nn.Sequential(*(train_transforms + shared_transforms)) + self.eval_transform = nn.Sequential(*(eval_transforms + shared_transforms)) + + self.corruption_process = corruption_process + + @property + def out_dim(self): + return self._out_dim + + @property + def device(self): + return next(self.transformer.parameters()).device + + def initialize_weights(self, **kwargs): + # Default random initialization or handled by HF + pass + + def init_seq( + self, + inputs: Optional[Union[np.ndarray, torch.Tensor]] = None, + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + corrupt_frac: float = 0.0, + **kwargs, + ): + # infer input type if not specified + if inputs is not None: + if isinstance(inputs, np.ndarray): + seq_array = inputs + if isinstance(inputs, LongTensor): + tgt_tok_idxs = inputs + elif isinstance(inputs, torch.Tensor): + src_tok_embs = inputs + msg = "inputs is deprecated, use a specific argument instead" + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + # Determine batch size from any available input + batch_size = None + if seq_array is not None: + batch_size = seq_array.shape[0] + elif tgt_tok_idxs is not None: + batch_size = tgt_tok_idxs.shape[0] + elif src_tok_embs is not None: + batch_size = src_tok_embs.shape[0] + + # Fallback to default batch size of 1 if no inputs are provided + if batch_size is None: + batch_size = 1 + + if "mask_frac" in kwargs: + corrupt_frac = kwargs["mask_frac"] 
+ msg = "mask_frac is deprecated, use corrupt_frac instead." + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + if self.corruption_process is not None and corrupt_frac is None: + corrupt_frac = self.corruption_process.sample_corrupt_frac(n=batch_size).to(self.device) + elif isinstance(corrupt_frac, float): + corrupt_frac = torch.full((batch_size,), corrupt_frac, device=self.device) + elif isinstance(corrupt_frac, torch.Tensor): + # Move tensor to the correct device + corrupt_frac = corrupt_frac.to(self.device) + else: + corrupt_frac = torch.full((batch_size,), 0.0, device=self.device) + + return seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac + + def tokenize_seq( + self, + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + corrupt_frac: Union[float, torch.Tensor] = 0.0, + is_corrupted: Optional[torch.Tensor] = None, + corruption_allowed: Optional[torch.Tensor] = None, + ): + # begin forward pass from raw sequence + if seq_array is not None: + assert tgt_tok_idxs is None + assert src_tok_embs is None + if self.training: + tgt_tok_idxs = self.train_transform(seq_array) + else: + tgt_tok_idxs = self.eval_transform(seq_array) + tgt_tok_idxs = tgt_tok_idxs.to(self.device) + + # truncate token sequence to max context length + if tgt_tok_idxs is not None: + assert src_tok_embs is None + # truncate to max context length, keep final stop token + if tgt_tok_idxs.size(-1) > self.max_len: + tmp_tok_idxs = tgt_tok_idxs[..., : self.max_len - 1] + tgt_tok_idxs = torch.cat([tmp_tok_idxs, tgt_tok_idxs[..., -1:]], dim=-1) + + if corruption_allowed is None and tgt_tok_idxs is not None: + corruption_allowed = self.tokenizer.get_corruptible_mask(tgt_tok_idxs) + + # begin forward pass from tokenized sequence + if tgt_tok_idxs is not None: + # apply masking corruption + if isinstance(self.corruption_process, MaskCorruptionProcess) and ( + (isinstance(corrupt_frac, float) and corrupt_frac > 0.0) + or (isinstance(corrupt_frac, torch.Tensor) and torch.any(corrupt_frac > 0.0)) + ): + src_tok_idxs, is_corrupted = self.corruption_process( + x_start=tgt_tok_idxs, + mask_val=self.masking_idx or self.tokenizer.mask_token_id, + corruption_allowed=corruption_allowed, + corrupt_frac=corrupt_frac, + ) + else: + src_tok_idxs = tgt_tok_idxs + is_corrupted = ( + torch.full_like(src_tok_idxs, False, dtype=torch.bool) if is_corrupted is None else is_corrupted + ) + + padding_mask = src_tok_idxs != self.pad_tok_idx + + if src_tok_embs is not None: + assert seq_array is None + assert padding_mask is not None + src_tok_idxs = None + + return ( + src_tok_idxs, + tgt_tok_idxs, + corruption_allowed, + is_corrupted, + padding_mask, + ) + + def forward( + self, + inputs: Optional[Union[np.ndarray, torch.Tensor]] = None, + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + corrupt_frac: Union[float, torch.Tensor] = 0.0, + is_corrupted: Optional[torch.Tensor] = None, + corruption_allowed: Optional[torch.Tensor] = None, + **kwargs, + ) -> TransformerDecoderRootOutput: + """ + Args: + seq_array: (batch_size,) array of discrete sequences (e.g. 
text strings) + tgt_tok_idxs: Optional pre-tokenized inputs + src_tok_embs: Optional pre-embedded inputs + padding_mask: Optional padding mask for pre-embedded inputs + corrupt_frac: Fraction of tokens to corrupt + is_corrupted: Optional pre-computed corruption mask + corruption_allowed: Optional mask of tokens that can be corrupted + + Returns: + TransformerDecoderRootOutput containing: + root_features: Transformer decoder output representations + padding_mask: Attention mask (1 for keep, 0 for padding) + src_tok_idxs: Source token indices (possibly corrupted) + tgt_tok_idxs: Target token indices (original) + src_tok_embs: Source token embeddings + is_corrupted: Mask indicating which tokens were corrupted + corrupt_frac: Fraction of tokens corrupted + """ + seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac = self.init_seq( + inputs, seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac, **kwargs + ) + ( + src_tok_idxs, + tgt_tok_idxs, + corruption_allowed, + is_corrupted, + padding_mask, + ) = self.tokenize_seq( + seq_array, + tgt_tok_idxs, + src_tok_embs, + padding_mask, + corrupt_frac, + is_corrupted, + corruption_allowed, + ) + + # Create HF attention mask (1 for keep, 0 for padding) from padding_mask + attention_mask = padding_mask.long() + + # Process through transformer model + if src_tok_idxs is not None: + outputs = self.transformer( + input_ids=src_tok_idxs, + attention_mask=attention_mask, + # No need to pass an explicit causal mask - HF decoder handles it internally + return_dict=True, + output_hidden_states=True, + ) + # Extract last hidden state + root_features = outputs.hidden_states[-1] + else: + # Handle the case when src_tok_embs is provided (less common for transformers) + outputs = self.transformer( + inputs_embeds=src_tok_embs, attention_mask=attention_mask, return_dict=True, output_hidden_states=True + ) + root_features = outputs.hidden_states[-1] + + # Make sure corrupt_frac is on the same device as other tensors + if isinstance(corrupt_frac, torch.Tensor): + corrupt_frac = corrupt_frac.to(root_features.device) + + outputs = TransformerDecoderRootOutput( + root_features=root_features.contiguous(), + padding_mask=padding_mask, + src_tok_idxs=src_tok_idxs, + tgt_tok_idxs=tgt_tok_idxs, + src_tok_embs=src_tok_embs, + is_corrupted=is_corrupted, + corrupt_frac=corrupt_frac, + ) + return outputs diff --git a/cortex/model/root/_transformer_encoder_root.py b/cortex/model/root/_transformer_encoder_root.py new file mode 100644 index 0000000..f7b4f78 --- /dev/null +++ b/cortex/model/root/_transformer_encoder_root.py @@ -0,0 +1,306 @@ +import warnings +from dataclasses import dataclass +from typing import Optional, Union + +import numpy as np +import torch +from omegaconf import DictConfig +from torch import LongTensor, nn +from transformers import AutoConfig, AutoModel + +from cortex.corruption import CorruptionProcess, MaskCorruptionProcess +from cortex.model.root import RootNode, RootNodeOutput +from cortex.transforms import HuggingFaceTokenizerTransform, PadTransform, ToTensor + + +@dataclass +class TransformerEncoderRootOutput(RootNodeOutput): + """Output of TransformerEncoderRoot.""" + + root_features: torch.Tensor + padding_mask: torch.Tensor + corrupt_frac: Optional[torch.Tensor] = None + src_tok_idxs: Optional[torch.LongTensor] = None + tgt_tok_idxs: Optional[torch.LongTensor] = None + src_tok_embs: Optional[torch.Tensor] = None + is_corrupted: Optional[torch.Tensor] = None + + +class TransformerEncoderRoot(RootNode): + """ + A root node that wraps a Hugging Face 
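On the decoder side, `forward` converts the padding mask into a Hugging Face `attention_mask`, lets the causal-LM model apply causal masking internally, and takes the last entry of `output_hidden_states` as the root features. A usage sketch; `tok_transform` is assumed to be an already-constructed `HuggingFaceTokenizerTransform` whose tokenizer exposes `padding_idx` and `get_corruptible_mask`, as the root expects:

```python
# Sketch: forward pass through TransformerDecoderRoot on raw sequences.
# `tok_transform` is assumed to exist already (a HuggingFaceTokenizerTransform instance).
import numpy as np
import torch

from cortex.model.root import TransformerDecoderRoot

root = TransformerDecoderRoot(
    tokenizer_transform=tok_transform,
    model_name_or_path="gpt2",  # any causal-LM checkpoint accepted by AutoModelForCausalLM
    max_len=128,
)
root.eval()

with torch.no_grad():
    out = root(seq_array=np.array(["hello world", "another sequence"]))

print(out.root_features.shape)  # (batch, max_len, hidden_size), e.g. (2, 128, 768) for gpt2
print(out.padding_mask.shape)   # (batch, max_len), True where tokens are not padding
```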
transformer encoder model (e.g., BERT, RoBERTa, LBSTER encoders). + + Example Hydra Config: + ```yaml + roots: + protein_encoder: + _target_: cortex.model.root.TransformerEncoderRoot + tokenizer_transform: ??? # Needs instantiation elsewhere + model_name_or_path: "facebook/esm2_t6_8M_UR50D" + use_pretrained: True + max_len: 512 + attn_implementation: "sdpa" + out_dim: 320 # Example, will be inferred + ``` + """ + + def __init__( + self, + tokenizer_transform: HuggingFaceTokenizerTransform, + model_name_or_path: str, + max_len: int, + out_dim: int = None, + use_pretrained: bool = True, + attn_implementation: Optional[str] = "sdpa", + config_overrides: Optional[DictConfig] = None, + corruption_process: Optional[CorruptionProcess] = None, + train_transforms=None, + eval_transforms=None, + **kwargs, + ) -> None: + super().__init__() + self.tokenizer = tokenizer_transform.tokenizer + self.pad_tok_idx = self.tokenizer.padding_idx + self.masking_idx = getattr(self.tokenizer, "masking_idx", None) + self.max_len = max_len + + # Load or create Hugging Face model configuration + if use_pretrained: + self.transformer = AutoModel.from_pretrained(model_name_or_path, attn_implementation=attn_implementation) + else: + config = AutoConfig.from_pretrained(model_name_or_path) + # Apply configuration overrides if specified + if config_overrides is not None: + for key, value in config_overrides.items(): + setattr(config, key, value) + self.transformer = AutoModel.from_config(config) + + # Determine output dimension from model + self._out_dim = self.transformer.config.hidden_size + + # Validate against provided out_dim if specified + if out_dim is not None and out_dim != self._out_dim: + warnings.warn( + f"Provided out_dim ({out_dim}) does not match model's hidden_size ({self._out_dim}). 
" + f"Using model's hidden_size.", + stacklevel=2, + ) + + # Set up transforms + shared_transforms = [ + tokenizer_transform, + ToTensor(padding_value=self.pad_tok_idx), + PadTransform(max_length=self.max_len, pad_value=self.pad_tok_idx), + ] + train_transforms = [] if train_transforms is None else list(train_transforms.values()) + eval_transforms = [] if eval_transforms is None else list(eval_transforms.values()) + self.train_transform = nn.Sequential(*(train_transforms + shared_transforms)) + self.eval_transform = nn.Sequential(*(eval_transforms + shared_transforms)) + + self.corruption_process = corruption_process + + @property + def out_dim(self): + return self._out_dim + + @property + def device(self): + return next(self.transformer.parameters()).device + + def initialize_weights(self, **kwargs): + # Default random initialization or handled by HF + pass + + def init_seq( + self, + inputs: Optional[Union[np.ndarray, torch.Tensor]] = None, + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + corrupt_frac: float = 0.0, + **kwargs, + ): + # infer input type if not specified + if inputs is not None: + if isinstance(inputs, np.ndarray): + seq_array = inputs + if isinstance(inputs, LongTensor): + tgt_tok_idxs = inputs + elif isinstance(inputs, torch.Tensor): + src_tok_embs = inputs + msg = "inputs is deprecated, use a specific argument instead" + warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + # Determine batch size from any available input + batch_size = None + if seq_array is not None: + batch_size = seq_array.shape[0] + elif tgt_tok_idxs is not None: + batch_size = tgt_tok_idxs.shape[0] + elif src_tok_embs is not None: + batch_size = src_tok_embs.shape[0] + + # Fallback to default batch size of 1 if no inputs are provided + if batch_size is None: + batch_size = 1 + + if "mask_frac" in kwargs: + corrupt_frac = kwargs["mask_frac"] + msg = "mask_frac is deprecated, use corrupt_frac instead." 
+ warnings.warn(msg, PendingDeprecationWarning, stacklevel=2) + + if self.corruption_process is not None and corrupt_frac is None: + corrupt_frac = self.corruption_process.sample_corrupt_frac(n=batch_size).to(self.device) + elif isinstance(corrupt_frac, float): + corrupt_frac = torch.full((batch_size,), corrupt_frac, device=self.device) + elif isinstance(corrupt_frac, torch.Tensor): + # Move tensor to the correct device + corrupt_frac = corrupt_frac.to(self.device) + else: + corrupt_frac = torch.full((batch_size,), 0.0, device=self.device) + + return seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac + + def tokenize_seq( + self, + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + corrupt_frac: Union[float, torch.Tensor] = 0.0, + is_corrupted: Optional[torch.Tensor] = None, + corruption_allowed: Optional[torch.Tensor] = None, + ): + # begin forward pass from raw sequence + if seq_array is not None: + assert tgt_tok_idxs is None + assert src_tok_embs is None + if self.training: + tgt_tok_idxs = self.train_transform(seq_array) + else: + tgt_tok_idxs = self.eval_transform(seq_array) + tgt_tok_idxs = tgt_tok_idxs.to(self.device) + + # truncate token sequence to max context length + if tgt_tok_idxs is not None: + assert src_tok_embs is None + # truncate to max context length, keep final stop token + if tgt_tok_idxs.size(-1) > self.max_len: + tmp_tok_idxs = tgt_tok_idxs[..., : self.max_len - 1] + tgt_tok_idxs = torch.cat([tmp_tok_idxs, tgt_tok_idxs[..., -1:]], dim=-1) + + if corruption_allowed is None and tgt_tok_idxs is not None: + corruption_allowed = self.tokenizer.get_corruptible_mask(tgt_tok_idxs) + + # begin forward pass from tokenized sequence + if tgt_tok_idxs is not None: + # apply masking corruption + if isinstance(self.corruption_process, MaskCorruptionProcess) and ( + (isinstance(corrupt_frac, float) and corrupt_frac > 0.0) + or (isinstance(corrupt_frac, torch.Tensor) and torch.any(corrupt_frac > 0.0)) + ): + src_tok_idxs, is_corrupted = self.corruption_process( + x_start=tgt_tok_idxs, + mask_val=self.masking_idx or self.tokenizer.mask_token_id, + corruption_allowed=corruption_allowed, + corrupt_frac=corrupt_frac, + ) + else: + src_tok_idxs = tgt_tok_idxs + is_corrupted = ( + torch.full_like(src_tok_idxs, False, dtype=torch.bool) if is_corrupted is None else is_corrupted + ) + + padding_mask = src_tok_idxs != self.pad_tok_idx + + if src_tok_embs is not None: + assert seq_array is None + assert padding_mask is not None + src_tok_idxs = None + + return ( + src_tok_idxs, + tgt_tok_idxs, + corruption_allowed, + is_corrupted, + padding_mask, + ) + + def forward( + self, + inputs: Optional[Union[np.ndarray, torch.Tensor]] = None, + seq_array: Optional[np.ndarray] = None, + tgt_tok_idxs: Optional[LongTensor] = None, + src_tok_embs: Optional[torch.Tensor] = None, + padding_mask: Optional[torch.Tensor] = None, + corrupt_frac: Union[float, torch.Tensor] = 0.0, + is_corrupted: Optional[torch.Tensor] = None, + corruption_allowed: Optional[torch.Tensor] = None, + **kwargs, + ) -> TransformerEncoderRootOutput: + """ + Args: + seq_array: (batch_size,) array of discrete sequences (e.g. 
text strings) + tgt_tok_idxs: Optional pre-tokenized inputs + src_tok_embs: Optional pre-embedded inputs + padding_mask: Optional padding mask for pre-embedded inputs + corrupt_frac: Fraction of tokens to corrupt + is_corrupted: Optional pre-computed corruption mask + corruption_allowed: Optional mask of tokens that can be corrupted + + Returns: + TransformerEncoderRootOutput containing: + root_features: Transformer encoder output representations + padding_mask: Attention mask (1 for keep, 0 for padding) + src_tok_idxs: Source token indices (possibly corrupted) + tgt_tok_idxs: Target token indices (original) + src_tok_embs: Source token embeddings + is_corrupted: Mask indicating which tokens were corrupted + corrupt_frac: Fraction of tokens corrupted + """ + seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac = self.init_seq( + inputs, seq_array, tgt_tok_idxs, src_tok_embs, corrupt_frac, **kwargs + ) + ( + src_tok_idxs, + tgt_tok_idxs, + corruption_allowed, + is_corrupted, + padding_mask, + ) = self.tokenize_seq( + seq_array, + tgt_tok_idxs, + src_tok_embs, + padding_mask, + corrupt_frac, + is_corrupted, + corruption_allowed, + ) + + # Create HF attention mask (1 for keep, 0 for padding) from padding_mask + attention_mask = padding_mask.long() + + # Process through transformer model + if src_tok_idxs is not None: + outputs = self.transformer(input_ids=src_tok_idxs, attention_mask=attention_mask, return_dict=True) + # Extract last hidden state + root_features = outputs.last_hidden_state + else: + # Handle the case when src_tok_embs is provided (less common for transformers) + outputs = self.transformer(inputs_embeds=src_tok_embs, attention_mask=attention_mask, return_dict=True) + root_features = outputs.last_hidden_state + + # Make sure corrupt_frac is on the same device as other tensors + if isinstance(corrupt_frac, torch.Tensor): + corrupt_frac = corrupt_frac.to(root_features.device) + + outputs = TransformerEncoderRootOutput( + root_features=root_features.contiguous(), + padding_mask=padding_mask, + src_tok_idxs=src_tok_idxs, + tgt_tok_idxs=tgt_tok_idxs, + src_tok_embs=src_tok_embs, + is_corrupted=is_corrupted, + corrupt_frac=corrupt_frac, + ) + return outputs
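Putting the pieces together: the encoder root emits `(root_features, padding_mask)` and the new branches expect the same pair inside a `PaddedTrunkOutput`. The sketch below wires them together directly; in a real model a trunk normally sits between root and branch, and both the keyword construction of `PaddedTrunkOutput` and the pre-built `tok_transform` (a `HuggingFaceTokenizerTransform`) are assumptions here:

```python
# Sketch: encoder root -> encoder branch, skipping the trunk for brevity.
import numpy as np
import torch

from cortex.model.branch import TransformerEncoderBranch
from cortex.model.root import TransformerEncoderRoot
from cortex.model.trunk import PaddedTrunkOutput

root = TransformerEncoderRoot(
    tokenizer_transform=tok_transform,               # assumed to be constructed elsewhere
    model_name_or_path="facebook/esm2_t6_8M_UR50D",  # hidden_size 320 for this checkpoint
    max_len=64,
)
branch = TransformerEncoderBranch(in_dim=root.out_dim, out_dim=root.out_dim, num_layers=2, nhead=8)

root.eval()
branch.eval()
with torch.no_grad():
    root_out = root(seq_array=np.array(["MKTAYIAKQR", "GAVLILKW"]))
    branch_out = branch(
        PaddedTrunkOutput(trunk_features=root_out.root_features, padding_mask=root_out.padding_mask)
    )

print(branch_out.pooled_features.shape)  # torch.Size([2, 320])
```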