|
| 1 | +import torch |
| 2 | + |
| 3 | +from aion.codecs.modules.utils import LayerNorm, GRN |
| 4 | + |
| 5 | + |
class ConvNextBlock1d(torch.nn.Module):
    """ConvNeXtV2 block, adapted to 1D.

    Modified from the original 2D implementation at
    https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py

    Applies a depthwise conv, channels-last LayerNorm, an inverted-bottleneck
    MLP (dim -> 4*dim -> dim) with GELU and GRN, and a residual connection.
    Input and output have the same shape: (B, C, N).

    Args:
        dim (int): Number of input (and output) channels.
    """

    def __init__(self, dim: int):
        super().__init__()
        # Depthwise conv: kernel 7 with padding 3 preserves sequence length.
        self.dwconv = torch.nn.Conv1d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )
        self.norm = LayerNorm(dim, eps=1e-6)
        # Pointwise/1x1 convs, implemented as linear layers on channels-last data.
        self.pwconv1 = torch.nn.Linear(dim, 4 * dim)
        self.act = torch.nn.GELU()
        self.grn = GRN(4 * dim)
        self.pwconv2 = torch.nn.Linear(4 * dim, dim)

    def forward(self, x):
        """Apply the block to ``x`` of shape (B, C, N); returns the same shape."""
        y = self.dwconv(x)
        y = y.permute(0, 2, 1)  # (B, C, N) -> (B, N, C)
        y = self.norm(y)
        y = self.pwconv1(y)
        y = self.act(y)
        y = self.grn(y)
        y = self.pwconv2(y)
        y = y.permute(0, 2, 1)  # (B, N, C) -> (B, C, N)
        return x + y  # residual connection
| 40 | + |
| 41 | + |
class ConvNextEncoder1d(torch.nn.Module):
    r"""ConvNeXt 1D encoder.

    Modified from https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py

    Alternates downsampling convolutions with stages of ``ConvNextBlock1d``:
    the stem downsamples by 4x and each of the remaining ``len(depths) - 1``
    transitions downsamples by 2x. A final channels-first LayerNorm is applied
    to the output.

    Args:
        in_chans: Number of input channels. Default: 2.
        depths: Number of blocks at each stage. Default: (3, 3, 9, 3).
        dims: Feature dimension at each stage. Default: (96, 192, 384, 768).
    """

    def __init__(
        self,
        in_chans: int = 2,
        depths: tuple[int, ...] = (3, 3, 9, 3),
        dims: tuple[int, ...] = (96, 192, 384, 768),
    ):
        super().__init__()
        assert len(depths) == len(dims), "depths and dims should have the same length"
        num_layers = len(depths)

        # Stem plus (num_layers - 1) intermediate downsampling conv layers.
        self.downsample_layers = torch.nn.ModuleList()
        stem = torch.nn.Sequential(
            torch.nn.Conv1d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.downsample_layers.append(stem)
        for i in range(num_layers - 1):
            downsample_layer = torch.nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                torch.nn.Conv1d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        # One stage of residual blocks per resolution level.
        self.stages = torch.nn.ModuleList()
        for i in range(num_layers):
            stage = torch.nn.Sequential(
                *[ConvNextBlock1d(dim=dims[i]) for _ in range(depths[i])]
            )
            self.stages.append(stage)

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for conv/linear weights, zeros for biases."""
        if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            # Guard: conv/linear layers created with bias=False have bias=None.
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Encode ``x`` of shape (B, in_chans, N) into (B, dims[-1], N')."""
        for downsample, stage in zip(self.downsample_layers, self.stages):
            x = downsample(x)
            x = stage(x)
        return self.norm(x)
| 106 | + |
| 107 | + |
class ConvNextDecoder1d(torch.nn.Module):
    r"""ConvNeXt 1D decoder. Essentially a mirrored version of the encoder.

    Alternates upsampling transposed convolutions with stages of
    ``ConvNextBlock1d``: every transition upsamples by 2x except the last one,
    which upsamples by 4x (mirroring the encoder stem).

    Args:
        in_chans: Number of input channels. Default: 768.
        depths: Number of blocks at each stage. Default: (3, 3, 9, 3).
        dims: Feature dimension at each stage. Default: (384, 192, 96, 2).
    """

    def __init__(
        self,
        in_chans: int = 768,
        depths: tuple[int, ...] = (3, 3, 9, 3),
        dims: tuple[int, ...] = (384, 192, 96, 2),
    ):
        super().__init__()
        assert len(depths) == len(dims), "depths and dims should have the same length"
        num_layers = len(depths)

        self.upsample_layers = torch.nn.ModuleList()

        stem = torch.nn.Sequential(
            torch.nn.ConvTranspose1d(in_chans, dims[0], kernel_size=2, stride=2),
            LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
        )
        self.upsample_layers.append(stem)

        for i in range(num_layers - 1):
            # The last transition upsamples 4x to mirror the encoder stem;
            # the others upsample 2x.
            factor = 2 if i < (num_layers - 2) else 4
            upsample_layer = torch.nn.Sequential(
                LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
                torch.nn.ConvTranspose1d(
                    dims[i],
                    dims[i + 1],
                    kernel_size=factor,
                    stride=factor,
                ),
            )
            self.upsample_layers.append(upsample_layer)

        # One stage of residual blocks per resolution level.
        self.stages = torch.nn.ModuleList()
        for i in range(num_layers):
            stage = torch.nn.Sequential(
                *[ConvNextBlock1d(dim=dims[i]) for _ in range(depths[i])]
            )
            self.stages.append(stage)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for conv/linear weights, zeros for biases."""
        if isinstance(m, (torch.nn.Conv1d, torch.nn.Linear)):
            torch.nn.init.trunc_normal_(m.weight, std=0.02)
            # Guard: conv/linear layers created with bias=False have bias=None.
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Decode ``x`` of shape (B, in_chans, N) into (B, dims[-1], N')."""
        for upsample, stage in zip(self.upsample_layers, self.stages):
            x = upsample(x)
            x = stage(x)
        return x
0 commit comments