@@ -168,9 +168,7 @@ def __init__(
168168 n_steps_input : int = 1 ,
169169 n_steps_output : int = 1 ,
170170 stride : int = 1 ,
171- # TODO: support for passing data from dict
172- input_channel_idxs : tuple [int , ...] | None = None ,
173- output_channel_idxs : tuple [int , ...] | None = None ,
171+ channel_idxs : tuple [int , ...] | None = None ,
174172 batch_size : int = 4 ,
175173 dtype : torch .dtype = torch .float32 ,
176174 ftype : str = "torch" ,
@@ -205,8 +203,7 @@ def __init__(
205203 n_steps_input = n_steps_input ,
206204 n_steps_output = n_steps_output ,
207205 stride = stride ,
208- input_channel_idxs = input_channel_idxs ,
209- output_channel_idxs = output_channel_idxs ,
206+ channel_idxs = channel_idxs ,
210207 autoencoder_mode = self .autoencoder_mode ,
211208 full_trajectory_mode = full_trajectory_mode ,
212209 dtype = dtype ,
@@ -237,8 +234,7 @@ def __init__(
237234 n_steps_input = n_steps_input ,
238235 n_steps_output = n_steps_output ,
239236 stride = stride ,
240- input_channel_idxs = input_channel_idxs ,
241- output_channel_idxs = output_channel_idxs ,
237+ channel_idxs = channel_idxs ,
242238 autoencoder_mode = self .autoencoder_mode ,
243239 full_trajectory_mode = full_trajectory_mode ,
244240 dtype = dtype ,
@@ -254,8 +250,7 @@ def __init__(
254250 n_steps_input = n_steps_input ,
255251 n_steps_output = n_steps_output ,
256252 stride = stride ,
257- input_channel_idxs = input_channel_idxs ,
258- output_channel_idxs = output_channel_idxs ,
253+ channel_idxs = channel_idxs ,
259254 autoencoder_mode = self .autoencoder_mode ,
260255 full_trajectory_mode = full_trajectory_mode ,
261256 dtype = dtype ,
@@ -275,8 +270,7 @@ def __init__(
275270 n_steps_input = n_steps_input ,
276271 n_steps_output = n_steps_output ,
277272 stride = stride ,
278- input_channel_idxs = input_channel_idxs ,
279- output_channel_idxs = output_channel_idxs ,
273+ channel_idxs = channel_idxs ,
280274 full_trajectory_mode = True ,
281275 dtype = dtype ,
282276 verbose = self .verbose ,
@@ -291,8 +285,7 @@ def __init__(
291285 n_steps_input = n_steps_input ,
292286 n_steps_output = n_steps_output ,
293287 stride = stride ,
294- input_channel_idxs = input_channel_idxs ,
295- output_channel_idxs = output_channel_idxs ,
288+ channel_idxs = channel_idxs ,
296289 full_trajectory_mode = True ,
297290 dtype = dtype ,
298291 verbose = self .verbose ,
0 commit comments