@@ -35,7 +35,7 @@ def lens_to_mask(
3535 length : int | None = None ,
3636) -> mx .array : # Bool['b n']
3737 if not exists (length ):
38- length = t .amax ()
38+ length = t .max ()
3939
4040 seq = mx .arange (length )
4141 return einx .less ("n, b -> b n" , seq , t )
@@ -844,7 +844,6 @@ def __init__(
844844 self ,
845845 transformer : dict | Transformer = None ,
846846 duration_predictor : dict | DurationPredictor | None = None ,
847- odeint_kwargs : dict = dict (atol = 1e-5 , rtol = 1e-5 , method = "midpoint" ),
848847 cond_drop_prob = 0.25 ,
849848 num_channels = None ,
850849 mel_spec_module : nn .Module | None = None ,
@@ -877,10 +876,6 @@ def __init__(
877876
878877 self .duration_predictor = duration_predictor
879878
880- # sampling
881-
882- self .odeint_kwargs = odeint_kwargs
883-
884879 # mel spec
885880
886881 self .mel_spec = default (mel_spec_module , MelSpec (** mel_spec_kwargs ))
@@ -1016,7 +1011,6 @@ def odeint(self, func, y0, t):
10161011
10171012 return mx .stack (ys )
10181013
1019- @mx .compile
10201014 def sample (
10211015 self ,
10221016 cond : mx .array ,
@@ -1049,7 +1043,7 @@ def sample(
10491043 assert text .shape [0 ] == batch
10501044
10511045 if exists (text ):
1052- text_lens = (text != - 1 ).sum (dim = - 1 )
1046+ text_lens = (text != - 1 ).sum (axis = - 1 )
10531047 lens = mx .maximum (
10541048 text_lens , lens
10551049 ) # make sure lengths are at least those of the text characters
@@ -1070,15 +1064,15 @@ def sample(
10701064 duration = mx .maximum (
10711065 lens + 1 , duration
10721066 ) # just add one token so something is generated
1073- duration = duration . clamp ( max = max_duration )
1067+ duration = mx . minimum ( duration , max_duration )
10741068
10751069 assert duration .shape [0 ] == batch
10761070
1077- max_duration = duration .amax ()
1071+ max_duration = duration .max (). item ()
10781072
1079- cond = mx .pad (cond , (0 , 0 , 0 , max_duration - cond_seq_len ), value = 0. 0 )
1073+ cond = mx .pad (cond , [ (0 , 0 ), ( 0 , max_duration - cond_seq_len ), ( 0 , 0 )], constant_values = 0 )
10801074 cond_mask = mx .pad (
1081- cond_mask , (0 , max_duration - cond_mask .shape [- 1 ]), value = False
1075+ cond_mask , [ (0 , 0 ), ( 0 , max_duration - cond_mask .shape [- 1 ])], constant_values = False
10821076 )
10831077 cond_mask = rearrange (cond_mask , "... -> ... 1" )
10841078
@@ -1088,7 +1082,7 @@ def sample(
10881082
10891083 def fn (t , x ):
10901084 # at each step, conditioning is fixed
1091-
1085+
10921086 step_cond = mx .where (cond_mask , cond , mx .zeros_like (cond ))
10931087
10941088 # predict flow
@@ -1100,7 +1094,7 @@ def fn(t, x):
11001094 y0 = mx .random .normal (cond .shape )
11011095 t = mx .linspace (0 , 1 , steps )
11021096
1103- trajectory = self .odeint (fn , y0 , t , ** self . odeint_kwargs )
1097+ trajectory = self .odeint (fn , y0 , t )
11041098 sampled = trajectory [- 1 ]
11051099
11061100 out = sampled
0 commit comments