@@ -54,8 +54,7 @@ class TripleCartesianProd(Data):
5454
5555 Args:
5656 X_train: A tuple of two NumPy arrays. The first element has the shape (`N1`,
57- `dim1`), and the second element has the shape (`N2`, `dim2`). The mini-batch
58- is only applied to `N1`.
57+ `dim1`), and the second element has the shape (`N2`, `dim2`).
5958 y_train: A NumPy array of shape (`N1`, `N2`).
6059 """
6160
@@ -71,16 +70,24 @@ def __init__(self, X_train, y_train, X_test, y_test):
7170 self .train_x , self .train_y = X_train , y_train
7271 self .test_x , self .test_y = X_test , y_test
7372
74- self .train_sampler = BatchSampler (len (X_train [0 ]), shuffle = True )
73+ self .branch_sampler = BatchSampler (len (X_train [0 ]), shuffle = True )
74+ self .trunk_sampler = BatchSampler (len (X_train [1 ]), shuffle = True )
7575
7676 def losses (self , targets , outputs , loss_fn , inputs , model , aux = None ):
7777 return loss_fn (targets , outputs )
7878
7979 def train_next_batch (self , batch_size = None ):
8080 if batch_size is None :
8181 return self .train_x , self .train_y
82- indices = self .train_sampler .get_next (batch_size )
83- return (self .train_x [0 ][indices ], self .train_x [1 ]), self .train_y [indices ]
82+ if not isinstance (batch_size , (tuple , list )):
83+ indices = self .branch_sampler .get_next (batch_size )
84+ return (self .train_x [0 ][indices ], self .train_x [1 ]), self .train_y [indices ]
85+ indices_branch = self .branch_sampler .get_next (batch_size [0 ])
86+ indices_trunk = self .trunk_sampler .get_next (batch_size [1 ])
87+ return (
88+ self .train_x [0 ][indices_branch ],
89+ self .train_x [1 ][indices_trunk ],
90+ ), self .train_y [indices_branch , indices_trunk ]
8491
8592 def test (self ):
8693 return self .test_x , self .test_y
0 commit comments