11from typing import Sequence
22import equinox as eqx
33from jaxtyping import Key
4+ import numpy as np
45import ml_collections
56
67from ._mixer import Mixer2d
78from ._mlp import ResidualNetwork
89from ._unet import UNet
9- from ._unet_xy import UNetXY
1010
1111
1212def get_model (
@@ -17,6 +17,12 @@ def get_model(
1717 parameter_dim : int ,
1818 config : ml_collections .ConfigDict
1919) -> eqx .Module :
20+ # Grab channel assuming 'q' is a map like x
21+ if context_shape is not None :
22+ context_channels , * _ = context_shape .shape
23+ else :
24+ context_channels = None
25+
2026 if model_type == "Mixer" :
2127 model = Mixer2d (
2228 data_shape ,
@@ -26,7 +32,7 @@ def get_model(
2632 mix_hidden_size = config .model .mix_hidden_size ,
2733 num_blocks = config .model .num_blocks ,
2834 t1 = config .t1 ,
29- q_dim = context_shape ,
35+ q_dim = context_channels ,
3036 a_dim = parameter_dim ,
3137 key = model_key
3238 )
@@ -42,22 +48,7 @@ def get_model(
4248 num_res_blocks = config .model .num_res_blocks ,
4349 attn_resolutions = config .model .attn_resolutions ,
4450 final_activation = config .model .final_activation ,
45- a_dim = parameter_dim ,
46- key = model_key
47- )
48- if model_type == "UNetXY" :
49- model = UNetXY (
50- data_shape = data_shape ,
51- is_biggan = config .model .is_biggan ,
52- dim_mults = config .model .dim_mults ,
53- hidden_size = config .model .hidden_size ,
54- heads = config .model .heads ,
55- dim_head = config .model .dim_head ,
56- dropout_rate = config .model .dropout_rate ,
57- num_res_blocks = config .model .num_res_blocks ,
58- attn_resolutions = config .model .attn_resolutions ,
59- final_activation = config .model .final_activation ,
60- q_dim = context_shape [0 ], # Just grab channel assuming 'q' is a map like x
51+ q_dim = context_channels ,
6152 a_dim = parameter_dim ,
6253 key = model_key
6354 )
@@ -68,9 +59,11 @@ def get_model(
6859 depth = config .model .depth ,
6960 activation = config .model .activation ,
7061 dropout_p = config .model .dropout_p ,
71- y_dim = parameter_dim ,
62+ q_dim = parameter_dim ,
7263 key = model_key
7364 )
65+ if model_type == "CCT" :
66+ raise NotImplementedError
7467 if model_type == "DiT" :
7568 raise NotImplementedError
7669 return model
0 commit comments