-
Notifications
You must be signed in to change notification settings - Fork 131
Expand file tree
/
Copy path3rdparty_nemo_evo2_tmp.patch
More file actions
22 lines (20 loc) · 1.07 KB
/
3rdparty_nemo_evo2_tmp.patch
File metadata and controls
22 lines (20 loc) · 1.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py
index 67500615e0..147e78cfa1 100644
--- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py
+++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_utils.py
@@ -647,7 +647,7 @@ class ExplicitSingleDecayFilter(nn.Module):
"""
return self.filter(L, *args, **kwargs)
- @torch.compile(mode="max-autotune")
+ #@torch.compile(mode="max-autotune")
def filter(self, L, *args, **kwargs):
"""Compute the filter as a function of h and decay for the requested sequence length."""
h = self.h[:, :L]
@@ -834,7 +834,7 @@ class ParallelHyenaOperator(nn.Module):
self.conv_bias.data = conv_init_method(self.conv_bias.data)
self.conv_bias.model_parallel = True
self.conv_bias.partition_dim = 0
- self.conv_bias.stride = 1
+ #self.conv_bias.stride = 1
def forward_long(self, *, x1, x2, v, h, bias, inference_context):
"""Forward pass long."""