Skip to content

Moment infini #76

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: moment_infini
Choose a base branch
from

Conversation

PotosnakW
Copy link

@PotosnakW PotosnakW commented Mar 27, 2025

New files:

  • momentfm/models/moment.py: supports infini channel mixing with config.infini_channel_mixing boolean and config.n_series (number of individual time series). Only support forecasting currently.
  • momentfm/utils.t5_infini.py: contains 'T5InfiniModel' class. if config.infini_channel_mixing==True then T5InfiniAttention is used, else the default T5Attention is used.

Copy link
Collaborator

@JanekDev JanekDev left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked the implementation with the paper and everything looks good! I left some minor stylistic comments + a small comment on positional bias. Additionally, it's good that infini-moment was moved to the other file.


x_enc = self.tokenizer(x=x_enc)
batch_size, n_channels, seq_len = x_enc.shape
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest unifying n_channels and n_series.

x: [batch_size x n_channels x n_patches x d_model]
output: [batch_size x n_channels x forecast_horizon]
"""
x = self.flatten(x) # x: [batch_size, n_series, n_patches, d_model]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggesting unification of n_channels and n_series

if not self.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.n_channels, self.n_heads, seq_length, key_length), device=hidden_states.device, dtype=hidden_states.dtype
) # Willa - should we use n_channels or just 1?
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the original implementation by Nina there is no channel axis, so it gets probably broadcasted and position biases are shared between channels, hence there should be probably 1?

# Vectorized infini attention computation across channels
sigma_k = self.elu(key_states) + 1.0 # [batch_size, n_series, n_heads, n_patch, dim]
sigma_k_transposed = sigma_k.transpose(-2, -1) # [batch_size, n_series, n_heads, dim, n_patch]
memory_matrix = torch.matmul(sigma_k_transposed, value_states).sum(dim=1).unsqueeze(1) # [batch_size, 1, n_heads, dim, dim] sum over channels then unsqueeze to enable broadcasting over channels
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for the purpose of making it easier to understand- can we split the computation of memory matrix into memory updates and only then sum them in the separate line? Implementation looks correct btw!

z = sigma_k.sum(dim=-2).unsqueeze(-1).sum(dim=1) # [batch_size, n_heads, dim, 1] sum over sequence length and channels
z = z.unsqueeze(dim=1) # [batch_size, 1, n_heads, dim, 1]
sigma_q = self.elu(query_states) + 1.0 # [batch_size, n_series, n_heads, n_patch, dim]
A_mem = (sigma_q @ memory_matrix) / ((sigma_q @ z) + 1e-6) # [batch_size, n_series, n_heads, n_patch, dim]/[batch_size, n_series, n_heads, n_patch, 1] --> [batch_size, n_series, n_heads, n_patch, dim] Adding 1e-6 for preventing division to 0
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe split this too?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants