diff --git a/momentfm/models/infini_moment.py b/momentfm/models/infini_moment.py
new file mode 100644
index 0000000..93dc70f
--- /dev/null
+++ b/momentfm/models/infini_moment.py
@@ -0,0 +1,208 @@
+import logging
+import warnings
+from argparse import Namespace
+from math import ceil
+from types import SimpleNamespace
+
+import torch
+from torch import nn
+from transformers import T5Config
+
+from momentfm.data.base import TimeseriesOutputs
+from momentfm.models.layers.embed import PatchEmbedding, Patching
+from momentfm.models.layers.revin import RevIN
+from momentfm.utils.masking import Masking
+from momentfm.utils.utils import NamespaceWithDefaults, _update_inputs, _validate_inputs
+from momentfm.utils.t5_infini import T5InfiniModel, T5InfiniEncoderModel
+
+logger = logging.getLogger(__name__)
+
+
+class ForecastingHead(nn.Module):
+    def __init__(self,
+                 head_nf: int = 768 * 64,
+                 forecast_horizon: int = 96,
+                 c_out: int = 1,
+                 head_dropout: float = 0.0):
+        super().__init__()
+        self.flatten = nn.Flatten(start_dim=-2)
+        self.dropout = nn.Dropout(head_dropout)
+        self.linear = nn.Linear(head_nf, forecast_horizon * c_out)  # NEW: c_out for loss dimension (potential for probabilistic predictions)
+
+    def forward(self, x, input_mask: torch.Tensor = None):
+        """
+        x: [batch_size x n_channels x n_patches x d_model]
+        output: [batch_size x n_channels x forecast_horizon * c_out]
+        """
+        x = self.flatten(x)  # x: [batch_size, n_channels, n_patches*d_model]
+        x = self.linear(x)   # x: [batch_size, n_channels, forecast_horizon*c_out]
+        x = self.dropout(x)  # x: [batch_size, n_channels, forecast_horizon*c_out]
+        return x
+
+
+class Long_Forecaster(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.d_model = config.d_model
+        self.patch_len = config.patch_len
+        self.stride = config.stride
+        self.transformer_type = config.transformer_type
+
+        self.revin = config.revin
+        if config.revin:
+            self.normalizer = RevIN(
+                num_features=config.n_channels,
+                affine=config.revin_affine
+            )
+
+        self.tokenizer = Patching(
+            patch_len=config.patch_len,
+            stride=config.stride,
+        )
+        self.patch_embedding = PatchEmbedding(
+            d_model=config.d_model,
+            seq_len=config.input_size,
+            patch_len=config.patch_len,
+            stride=config.stride,
+            dropout=config.dropout,
+            add_positional_embedding=True,
+            value_embedding_bias=False,
+            orth_gain=1.41,
+        )
+        self.mask_generator = Masking(mask_ratio=0.0)  # no masking for the forecasting task
+
+        # Transformer backbone
+        self.encoder = self._get_huggingface_transformer(config)
+
+        # Prediction head
+        num_patches = (
+            (max(config.input_size, config.patch_len) - config.patch_len)
+            // config.stride + 1
+        )
+        head_nf = config.d_model * num_patches
+        self.head = ForecastingHead(
+            head_nf,
+            config.h,
+            config.c_out,
+            config.head_dropout,
+        )
+
+    def _get_huggingface_transformer(self, configs):
+        ModelClass, EncoderModelClass = T5InfiniModel, T5InfiniEncoderModel
+        logger.info(f"ModelClass: {ModelClass.__name__}, EncoderModelClass: {EncoderModelClass.__name__}.")
+
+        model_config = T5Config.from_pretrained(configs.transformer_backbone)
+
+        setattr(model_config, 'infini_channel_mixing', configs.infini_channel_mixing)
+        setattr(model_config, 'use_rope', configs.use_rope)
+        setattr(model_config, 'max_sequence_length', configs.input_size / configs.patch_len)
+        setattr(model_config, 'n_channels', configs.n_channels)
+
+        transformer_backbone = ModelClass(model_config)
+        logger.info(
+            f"Initializing a randomly initialized transformer from {configs.transformer_backbone}. "
+            f"ModelClass: {ModelClass.__name__}."
+        )
+
+        transformer_backbone = transformer_backbone.get_encoder()  # check valid inputs to raise an error if not encoder-only
+
+        if configs.getattr('enable_gradient_checkpointing', True):
+            transformer_backbone.gradient_checkpointing_enable()
+            logger.info("Enabling gradient checkpointing.")
+
+        return transformer_backbone
+
+    def forward(self,
+                *,
+                x_enc: torch.Tensor,
+                input_mask: torch.Tensor = None,
+                **kwargs
+                ) -> TimeseriesOutputs:
+        """
+        x_enc : [batch_size x n_channels x seq_len]
+        input_mask : [batch_size x seq_len]
+        """
+        batch_size, n_channels, seq_len = x_enc.shape
+        input_mask = torch.ones(batch_size, seq_len).to(x_enc.device)  # [batch_size, seq_len]
+
+        # Normalization
+        if self.revin:
+            x_enc = self.normalizer(x=x_enc, mask=input_mask, mode='norm')
+            x_enc = torch.nan_to_num(x_enc, nan=0, posinf=0, neginf=0)
+
+        # Patching and embedding
+        x_enc = self.tokenizer(x=x_enc)  # [batch_size x n_channels x n_patch x patch_len]
+        enc_in = self.patch_embedding(x_enc, mask=torch.ones_like(input_mask))
+
+        n_patches = enc_in.shape[2]
+        enc_in = enc_in.reshape(
+            (batch_size * n_channels, n_patches, self.d_model))  # [batch_size*n_channels, n_patch, d_model]
+
+        # Encoder
+        attention_mask = Masking.convert_seq_to_patch_view(
+            mask=input_mask,
+            patch_len=self.patch_len,
+            stride=self.stride).repeat_interleave(n_channels, dim=0)  # [batch_size*n_channels, n_patch]
+
+        outputs = self.encoder(inputs_embeds=enc_in, attention_mask=attention_mask)
+        enc_out = outputs.last_hidden_state
+
+        enc_out = enc_out.reshape(
+            (-1, n_channels, n_patches, self.d_model))  # [batch_size, n_channels, n_patch, d_model]
+
+        # Prediction head ("decoder")
+        dec_out = self.head(enc_out)  # [batch_size, n_channels, horizon*c_out]
+
+        # De-normalization
+        if self.revin:
+            dec_out = self.normalizer(x=dec_out, mode='denorm')  # [batch_size, n_channels, horizon*c_out]
+
+        return TimeseriesOutputs(input_mask=input_mask, forecast=dec_out)
+
+
+class MOMENT(nn.Module):
+    def __init__(self, config: Namespace | dict, **kwargs: dict):
+        super().__init__()
+
+        if isinstance(config, (Namespace, SimpleNamespace)):
+            pass  # Namespace-like configs need no preprocessing; c_out is set below
+        elif isinstance(config, dict):
+            config['c_out'] = 1
+
+        config = _update_inputs(config)
+        config = _validate_inputs(config)
+        setattr(config, 'c_out', 1)  # self.loss.outputsize_multiplier --> NEW: c_out for loss dimension (potential for probabilistic predictions)
+        self.h = config.h
+        self.input_size = config.input_size
+
+        # Channel-independent: n_channels=1; channel-dependent/multivariate prediction: n_channels=n_channels
+        if not hasattr(config, 'n_channels'):
+            raise AttributeError("config is missing required (int) attribute 'n_channels'")
+        if not hasattr(config, 'infini_channel_mixing'):
+            raise AttributeError("config is missing required (bool) attribute 'infini_channel_mixing'")
+
+        if not config.infini_channel_mixing:
+            setattr(config, 'n_channels', 1)
+
+        if config.task_name == 'forecasting':
+            self.model = Long_Forecaster(config)
+        else:
+            raise NotImplementedError(f"Task {config.task_name} not implemented.")
+
+    def forward(
+        self,
+        *,
+        x_enc: torch.Tensor,
+        input_mask: torch.Tensor = None,
+        mask: torch.Tensor = None,
+        **kwargs,
+    ) -> TimeseriesOutputs:
+        # x_enc: [batch_size, n_channels, seq_len]
+        if input_mask is None:
+            input_mask = torch.ones_like(x_enc[:, 0, :])
+
+        return self.model(x_enc=x_enc, mask=mask, input_mask=input_mask, **kwargs)  # forecast: [batch_size, n_channels, horizon*c_out]
diff --git a/momentfm/utils/t5_infini.py b/momentfm/utils/t5_infini.py
new file
mode 100644
index 0000000..695728e
--- /dev/null
+++ b/momentfm/utils/t5_infini.py
@@ -0,0 +1,694 @@
+import math
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
+from typing import Optional
+
+from transformers.models.t5.modeling_t5 import T5Stack, T5Block, T5LayerNorm, T5Model, T5Config, T5EncoderModel, T5LayerCrossAttention, T5LayerSelfAttention, T5LayerFF
+
+from transformers.modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+)
+from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+class T5Attention(nn.Module):  # Default T5Attention copied from HuggingFace for version control
+    def __init__(
+        self,
+        config: T5Config,
+        has_relative_attention_bias=False,
+        layer_idx: Optional[int] = None,
+    ):
+        super().__init__()
+        self.is_decoder = config.is_decoder
+        self.has_relative_attention_bias = has_relative_attention_bias
+        self.relative_attention_num_buckets = config.relative_attention_num_buckets
+        self.relative_attention_max_distance = config.relative_attention_max_distance
+        self.d_model = config.d_model
+        self.key_value_proj_dim = config.d_kv
+        self.n_heads = config.num_heads
+        self.dropout = config.dropout_rate
+        self.inner_dim = self.n_heads * self.key_value_proj_dim
+        self.layer_idx = layer_idx
+        if layer_idx is None and self.is_decoder:
+            logger.warning_once(
+                f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+                "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        # Mesh TensorFlow initialization to avoid scaling before softmax
+        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
+        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
+
+        if self.has_relative_attention_bias:
+            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
+        self.pruned_heads = set()
+        self.gradient_checkpointing = False
+
+    def prune_heads(self, heads):
+        if len(heads) == 0:
+            return
+        heads, index = find_pruneable_heads_and_indices(
+            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
+        )
+        # Prune linear layers
+        self.q = prune_linear_layer(self.q, index)
+        self.k = prune_linear_layer(self.k, index)
+        self.v = prune_linear_layer(self.v, index)
+        self.o = prune_linear_layer(self.o, index, dim=1)
+        # Update hyper params
+        self.n_heads = self.n_heads - len(heads)
+        self.inner_dim = self.key_value_proj_dim * self.n_heads
+        self.pruned_heads = self.pruned_heads.union(heads)
+
+    @staticmethod
+    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
+        """
+        Adapted from Mesh Tensorflow:
+        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
+
+        Translate relative position to a bucket number for relative attention. The relative position is defined as
+        memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
+        position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
+        small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
+        positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
+ This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None, cache_position=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) + return values + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
+ """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] + + if mask is not None: + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + + # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * 
layer_head_mask + + attn_output = torch.matmul(attn_weights, value_states) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) + attn_output = self.o(attn_output) + + outputs = (attn_output, past_key_value, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + +class T5InfiniAttention(nn.Module): + def __init__(self, + config: T5Config, + has_relative_attention_bias=False, + layer_idx: Optional[int] = None, + **kwargs + ): + super().__init__() + self.is_decoder = config.is_decoder + self.has_relative_attention_bias = has_relative_attention_bias + self.relative_attention_num_buckets = config.relative_attention_num_buckets + self.relative_attention_max_distance = config.relative_attention_max_distance + self.d_model = config.d_model + self.key_value_proj_dim = config.d_kv + self.n_heads = config.num_heads + self.dropout = config.dropout_rate + self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + # Mesh TensorFlow initialization to avoid scaling before softmax + self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.k = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.v = nn.Linear(self.d_model, self.inner_dim, bias=False) + self.o = nn.Linear(self.inner_dim, self.d_model, bias=False) + + if self.has_relative_attention_bias: + self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads) + self.pruned_heads = set() + self.gradient_checkpointing = False + + # new parameters for infini_channel_mixing + self.use_rope = config.use_rope + self.elu = nn.ELU() + self.n_channels = config.n_channels + # check on beta initialization --> people have used zeros and random, which one is best? + self.beta = nn.Parameter(torch.rand((1, 1, self.n_heads, 1, 1))*1e-2) # Ablation exps: make C=n_channels for channel specific beta, can implement lasso for spasifying beta parameters, beta(s) for n_channels + # Adjust the values to ensure they sum to 0 --> CHECK THIS: we shouldn't need to do this because torch.rand samples from a normal distribution + with torch.no_grad(): + self.beta -= self.beta.mean() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices( + heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads + ) + # Prune linear layers + self.q = prune_linear_layer(self.q, index) + self.k = prune_linear_layer(self.k, index) + self.v = prune_linear_layer(self.v, index) + self.o = prune_linear_layer(self.o, index, dim=1) + + # Update hyper params + self.n_heads = self.n_heads - len(heads) + self.inner_dim = self.key_value_proj_dim * self.n_heads + self.pruned_heads = self.pruned_heads.union(heads) + + @staticmethod + def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + + Translate relative position to a bucket number for relative attention. 
The relative position is defined as + memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for + small absolute relative_position and larger buckets for larger absolute relative_positions. All relative + positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket. + This should allow for more graceful generalization to longer sequences than the model has been trained on + + Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + + Returns: + a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets) + """ + relative_buckets = 0 + if bidirectional: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + def compute_bias(self, query_length, key_length, device=None, cache_position=None): + """Compute binned relative position bias""" + if device is None: + device = self.relative_attention_bias.weight.device + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + + memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] + relative_position = memory_position - context_position # shape (query_length, key_length) + relative_position_bucket = self._relative_position_bucket( + relative_position, # shape (query_length, key_length) + bidirectional=(not self.is_decoder), + num_buckets=self.relative_attention_num_buckets, + max_distance=self.relative_attention_max_distance, + ) + values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads) + values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0) # shape (1, 1, num_heads, query_length, key_length) --> NEW: added dimension=1 for n_channels + return values + + def _update_memory_matrix(self, key_states, value_states): + sigma_k = self.elu(key_states) + 1.0 # [batch_size, n_channels, n_heads, n_patch, dim] + sigma_k_transposed = sigma_k.transpose(-2, -1) # [batch_size, n_channels, n_heads, dim, n_patch] + + memory_matrix = torch.matmul(sigma_k_transposed, value_states).sum(dim=1).unsqueeze(1) # [batch_size, 1, n_heads, dim, dim] sum over channels then unsqueeze to enable broadcasting over 
channels + + z = sigma_k.sum(dim=-2).unsqueeze(-1).sum(dim=1) # [batch_size, n_heads, dim, 1] sum over sequence length and channels + z = z.unsqueeze(dim=1) # [batch_size, 1, n_heads, dim, 1] + + return memory_matrix, z + + def _retrieve_from_memory(self, query_states, memory_matrix, z): + sigma_q = self.elu(query_states) + 1.0 # [batch_size, n_channels, n_heads, n_patch, dim] + A_mem = (sigma_q @ memory_matrix) / ((sigma_q @ z) + 1e-6) # [batch_size, n_channels, n_heads, n_patch, dim]/[batch_size, n_channels, n_heads, n_patch, 1] --> [batch_size, n_channels, n_heads, n_patch, dim] Adding 1e-6 for preventing division to 0 + + return A_mem + + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, n_patch, key_length) (causal decoder) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None + + query_states = self.q(hidden_states) + query_states = query_states.view(batch_size, + -1, + self.n_heads, + self.key_value_proj_dim).transpose(1, 2) # [batch_size, n_heads, n_patch, dim] + query_states = query_states.view(batch_size//self.n_channels, + self.n_channels, + self.n_heads, + seq_length, + self.key_value_proj_dim) # [batch_size, n_channels, n_heads, n_patch, dim] + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: + # reuse k,v, cross_attentions + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] + else: + key_states = self.k(current_states) + value_states = self.v(current_states) + key_states = key_states.view(batch_size, + -1, + self.n_heads, + self.key_value_proj_dim).transpose(1, 2) + key_states = key_states.view(batch_size//self.n_channels, + self.n_channels, + self.n_heads, + seq_length, + self.key_value_proj_dim) # [batch_size, n_channels, n_heads, n_patch, dim] + value_states = value_states.view(batch_size, + -1, + self.n_heads, + self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size//self.n_channels, + self.n_channels, + self.n_heads, + seq_length, + self.key_value_proj_dim) # [batch_size, n_channels, n_heads, n_patch, dim] + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is 
already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + if position_bias is None: + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, 1, self.n_heads, seq_length, key_length), device=hidden_states.device, dtype=hidden_states.dtype + ) # NEW: added dim(1) for n_channels + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias( + real_seq_length, key_length, device=hidden_states.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, :, -seq_length:, :] + + if mask is not None: + #causal_mask = mask[:, :, :, :, : key_states.shape[-2]] + #position_bias = position_bias + causal_mask + mask = mask.view(batch_size//self.n_channels, self.n_channels, 1, seq_length, 1) # [batch_size, n_channels, 1, n_patch, 1] + position_bias = position_bias + mask + + if self.pruned_heads: + head_mask = torch.ones(position_bias.shape[1]) + head_mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, head_mask.bool()] + else: + position_bias_masked = position_bias + + # Infini attention computation across channels + memory_matrix, z = self._update_memory_matrix(key_states, value_states) + A_mem = self._retrieve_from_memory(query_states, memory_matrix, z) + + scores = query_states @ key_states.transpose(-2, -1) # [batch_size, n_channels, n_heads, n_patch, n_patch] + scores += position_bias_masked # [batch_size, n_channels, n_heads, n_patch, n_patch] + + scores = scores / torch.sqrt(torch.tensor(self.key_value_proj_dim, + device=hidden_states.device, + dtype=torch.float16 + ) + ) # [batch_size, n_channels, n_heads, n_patch, n_patch] + + attn_weights = F.softmax(scores, dim=-1) # [batch_size, n_channels, n_heads, n_patch, n_patch] + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # [batch_size, n_channels, n_heads, n_patch, n_patch] + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = attn_weights @ value_states # [batch_size, n_channels, n_heads, n_patch, dim] + + attn_output = F.sigmoid(self.beta) * A_mem + (1 - F.sigmoid(self.beta)) * attn_output # [batch_size, n_channels, n_heads, n_patch, dim] + + attn_output = attn_output.transpose(2, 3).contiguous() # [batch_size, n_channels, n_patch, n_heads, dim] + attn_output = attn_output.view(batch_size, -1, self.inner_dim) # [batch_size*n_channels, n_patch, n_heads*dim] + attn_output = self.o(attn_output) # [batch_size*n_channels, n_patch, n_heads*dim] + + outputs = (attn_output, past_key_value, position_bias) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + +class T5LayerSelfInfiniAttention(nn.Module): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__() + + if config.infini_channel_mixing: + self.SelfAttention = T5InfiniAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + else: + self.SelfAttention = T5Attention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) + self.layer_norm = T5LayerNorm(config.d_model, 
eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + ) + hidden_states = hidden_states + self.dropout(attention_output[0]) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + +class T5LayerCrossInfiniAttention(nn.Module): + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + + if config.infini_channel_mixing: + self.EncDecAttention = T5InfiniAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + else: + self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx) + self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + cache_position=None, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + cache_position=cache_position, + ) + layer_output = hidden_states + self.dropout(attention_output[0]) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + +class T5InfiniBlock(T5Block): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): + super().__init__(config) + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append( + T5LayerSelfInfiniAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx) + ) + if self.is_decoder: + self.layer.append(T5LayerCrossInfiniAttention(config, layer_idx=layer_idx)) + + self.layer.append(T5LayerFF(config)) + +class T5InfiniStack(T5Stack): + def __init__(self, config, embed_tokens=None): + super().__init__(config) + + self.embed_tokens = embed_tokens + self.is_decoder = config.is_decoder + + self.block = nn.ModuleList( + [T5InfiniBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)] + ) + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.model_parallel = False + self.device_map = None + self.gradient_checkpointing = False + +class T5InfiniModel(T5Model): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, 
config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5InfiniStack(encoder_config, self.shared) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5InfiniStack(decoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + +class T5InfiniEncoderModel(T5EncoderModel): + _tied_weights_keys = ["encoder.embed_tokens.weight"] + + def __init__(self, config: T5Config): + super().__init__(config) + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5InfiniStack(encoder_config, self.shared) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None diff --git a/momentfm/utils/utils.py b/momentfm/utils/utils.py index ede4c7b..1749233 100644 --- a/momentfm/utils/utils.py +++ b/momentfm/utils/utils.py @@ -6,25 +6,53 @@ import numpy as np import torch - -class NamespaceWithDefaults(Namespace): +SUPPORTED_HUGGINGFACE_MODELS = [ + 't5-small', 't5-base', 't5-large', 't5-3b', 't5-11b', + 'google/flan-t5-small', 'google/flan-t5-base', + 'google/flan-t5-large', 'google/flan-t5-xl', + 'google/flan-t5-xxl', + 'google/t5-efficient-tiny', 'google/t5-efficient-mini', + 'google/t5-efficient-small', 'google/t5-efficient-medium', + 'google/t5-efficient-large', 'google/t5-efficient-base', +] + +# class NamespaceWithDefaults(Namespace): +# @classmethod +# def from_namespace(cls, namespace): +# new_instance = cls() +# for attr in dir(namespace): +# if not attr.startswith("__"): +# setattr(new_instance, attr, getattr(namespace, attr)) +# return new_instance + +# def getattr(self, key, default=None): +# return getattr(self, key, default) + +class NamespaceWithDefaults(Namespace): # NEW: for converting dictionary @classmethod def from_namespace(cls, namespace): new_instance = cls() - for attr in dir(namespace): - if not attr.startswith("__"): - setattr(new_instance, attr, getattr(namespace, attr)) - return new_instance + if isinstance(namespace, dict): + # Handle the case where namespace is a dictionary + for key, value in namespace.items(): + setattr(new_instance, key, value) + + elif isinstance(namespace, Namespace): + # Handle the case where namespace is a Namespace object + for attr in dir(namespace): + if not attr.startswith('__'): + setattr(new_instance, attr, getattr(namespace, attr)) + + return new_instance + def getattr(self, key, default=None): return getattr(self, key, default) - def parse_config(config: dict) -> NamespaceWithDefaults: args = NamespaceWithDefaults(**config) return args - def make_dir_if_not_exists(path, verbose=True): if not is_directory(path): path = path.split(".")[0] @@ -34,7 +62,6 @@ def make_dir_if_not_exists(path, verbose=True): print(f"Making directory: {path}...") return True - def is_directory(path): extensions = [".pth", ".txt", ".json", ".yaml"] @@ -43,7 +70,6 @@ def is_directory(path): return False return True - def control_randomness(seed: int = 13): random.seed(seed) os.environ["PYTHONHASHSEED"] = str(seed) @@ -53,7 +79,6 @@ def control_randomness(seed: 
int = 13):
     torch.backends.cudnn.deterministic = True
     torch.backends.cudnn.benchmark = False
-
 def dtype_map(dtype: str):
     map = {
         "float16": torch.float16,
@@ -69,13 +94,38 @@ def dtype_map(dtype: str):
     }
     return map[dtype]
-
-def get_huggingface_model_dimensions(model_name: str = "flan-t5-base"):
+def get_huggingface_model_dimensions(model_name: str = "flan-t5-base"):
     from transformers import T5Config
-
     config = T5Config.from_pretrained(model_name)
     return config.d_model
+
+def _update_inputs(configs: Namespace | dict, **kwargs) -> NamespaceWithDefaults:
+    if isinstance(configs, dict) and 'model_kwargs' in kwargs:
+        return NamespaceWithDefaults(**{**configs, **kwargs['model_kwargs']})
+    else:
+        return NamespaceWithDefaults.from_namespace(configs)
+
+def _validate_inputs(configs: NamespaceWithDefaults) -> NamespaceWithDefaults:
+    if configs.transformer_backbone == "PatchTST" and configs.transformer_type != "encoder_only":
+        warnings.warn("PatchTST only supports encoder-only transformer backbones.")
+        configs.transformer_type = "encoder_only"
+    if configs.transformer_backbone != "PatchTST" and configs.transformer_backbone not in SUPPORTED_HUGGINGFACE_MODELS:
+        raise NotImplementedError(f"Transformer backbone {configs.transformer_backbone} not supported. "
+                                  f"Please choose from {SUPPORTED_HUGGINGFACE_MODELS} or PatchTST.")
+    if configs.d_model is None and configs.transformer_backbone in SUPPORTED_HUGGINGFACE_MODELS:
+        configs.d_model = get_huggingface_model_dimensions(configs.transformer_backbone)
+        logging.info("Setting d_model to {}".format(configs.d_model))
+    elif configs.d_model is None:
+        raise ValueError("d_model must be specified unless the transformer backbone is a Hugging Face model.")
+
+    if configs.transformer_type not in ["encoder_only", "decoder_only", "encoder_decoder"]:
+        raise ValueError("transformer_type must be one of ['encoder_only', 'decoder_only', 'encoder_decoder']")
+
+    if configs.stride != configs.patch_len:
+        warnings.warn("Patch stride length is not equal to patch length.")
+
+    return configs
 
 def get_anomaly_criterion(anomaly_criterion: str = "mse"):
     if anomaly_criterion == "mse":
@@ -85,7 +135,6 @@ def get_anomaly_criterion(anomaly_criterion: str = "mse"):
     else:
         raise ValueError(f"Anomaly criterion {anomaly_criterion} not supported.")
-
 def _reduce(metric, reduction="mean", axis=None):
     if reduction == "mean":
         return np.nanmean(metric, axis=axis)
@@ -94,7 +143,6 @@ def _reduce(metric, reduction="mean", axis=None):
     elif reduction == "none":
         return metric
-
 class EarlyStopping:
     def __init__(self, patience: int = 3, verbose: bool = False, delta: float = 0):
         self.patience = patience
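
Usage sketch (not part of the patch): the snippet below shows how the forecasting path added in this diff could be exercised end to end. It is a minimal sketch under stated assumptions, not a definitive recipe: the config field values are illustrative examples, and it assumes the upstream momentfm components referenced above (Patching, PatchEmbedding, RevIN, Masking, TimeseriesOutputs) behave as in the published package. T5Config.from_pretrained will fetch the backbone config on first use.

    from argparse import Namespace

    import torch

    from momentfm.models.infini_moment import MOMENT

    # Illustrative configuration covering the attributes read by MOMENT,
    # Long_Forecaster, and _validate_inputs in this diff. Values are examples only.
    config = Namespace(
        task_name='forecasting',
        transformer_backbone='google/flan-t5-small',  # must be in SUPPORTED_HUGGINGFACE_MODELS
        transformer_type='encoder_only',
        d_model=None,                   # filled in from the T5 config by _validate_inputs
        input_size=512,                 # context length (seq_len)
        patch_len=8,
        stride=8,                       # _validate_inputs warns if stride != patch_len
        h=96,                           # forecast horizon
        n_channels=7,
        infini_channel_mixing=True,     # route self-attention through T5InfiniAttention
        use_rope=False,
        revin=True,
        revin_affine=False,
        dropout=0.1,
        head_dropout=0.1,
        enable_gradient_checkpointing=False,
    )

    model = MOMENT(config)

    x_enc = torch.randn(4, config.n_channels, config.input_size)  # [batch, n_channels, seq_len]
    out = model(x_enc=x_enc)
    print(out.forecast.shape)  # expected: torch.Size([4, 7, 96]) since c_out is fixed to 1

With infini_channel_mixing=True the encoder keeps n_channels=7 and mixes channels inside the Infini-attention memory; setting it to False forces n_channels=1, i.e. the original channel-independent MOMENT behaviour.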