@@ -327,8 +327,17 @@ class MixtureOfDatasets(RLDataset):
327327 Args:
328328 datasets: A list of :py:class:`~amago.loading.RLDataset` objects.
329329 sampling_weights: Probability of sampling from each dataset. Must sum to 1.
330- smooth_sudden_starts: When a dataset becomes ready for training mid-way through training,
331- anneal its sampling_weight from 0 --> assigned ``sampling_weights[i]`` over ``smooth_sudden_starts`` epochs.
330+ These are the *final* target weights after annealing completes.
331+ initial_sampling_weights: Optional starting weights before annealing. When
332+ provided, each dataset anneals from ``initial_sampling_weights[i]`` to
333+ ``sampling_weights[i]`` over ``smooth_sudden_starts`` epochs (starting
334+ from the epoch the dataset becomes ready for training). When ``None``
335+ (default), datasets that are ready at the start of training jump
336+ straight to their target weight, while datasets that become ready later
337+ anneal up from 0.
338+ smooth_sudden_starts: Number of epochs over which to anneal each dataset's
339+ sampling weight from its initial value to its target value, starting
340+ from the epoch the dataset becomes ready for training.
332341 dset_name: The name of the dataset. Used to identify the dataset in logging metrics. Defaults to the class name.
333342
334343 Note:
@@ -345,6 +354,7 @@ def __init__(
345354 self ,
346355 datasets : list [RLDataset ],
347356 sampling_weights : list [float ],
357+ initial_sampling_weights : Optional [list [float ]] = None ,
348358 smooth_sudden_starts : Optional [int ] = None ,
349359 dset_name : Optional [str ] = None ,
350360 ):
@@ -357,8 +367,15 @@ def __init__(
357367 raise ValueError (
358368 f"{ self .__class__ .__name__ } requires the sum of sampling weights to be 1.0"
359369 )
370+ if initial_sampling_weights is not None and len (
371+ initial_sampling_weights
372+ ) != len (datasets ):
373+ raise ValueError (
374+ f"{ self .__class__ .__name__ } requires the same number of initial_sampling_weights ({ len (initial_sampling_weights )} ) as datasets ({ len (datasets )} )"
375+ )
360376 self .all_datasets = datasets
361377 self .sampling_weights = sampling_weights
378+ self .initial_sampling_weights = initial_sampling_weights
362379 self .smooth_sudden_starts = smooth_sudden_starts
363380
364381 # find up to 1 dataset that determines where Experiment will write new trajectories
@@ -378,10 +395,13 @@ def configure_from_experiment(self, experiment):
378395
379396 # with datasets configured, initialize our annealing schedule
380397 self ._dsets_status = []
381- for d , w in zip (self .all_datasets , self .sampling_weights ):
382- # if the dataset is available from the start of training, turn off its schedule
383- # setting initial_weight to the final weight.
384- initial_weight = w if d .ready_for_training else 0
398+ for i , (d , w ) in enumerate (zip (self .all_datasets , self .sampling_weights )):
399+ if self .initial_sampling_weights is not None :
400+ initial_weight = self .initial_sampling_weights [i ]
401+ else :
402+ # legacy behavior: ready datasets start at their final weight,
403+ # not-yet-ready datasets anneal up from 0
404+ initial_weight = w if d .ready_for_training else 0
385405 self ._dsets_status .append (
386406 _DatasetStatus (
387407 dataset = d ,
@@ -416,22 +436,17 @@ def update_dset_weights(self, epoch: int):
416436 for status in self ._dsets_status :
417437 if status .epoch_ready is not None :
418438 if self .smooth_sudden_starts is None :
419- # active datasets jump right to their final weight
420439 current_weight = status .final_weight
421440 else :
422- # linear schedule for smooth_sudden_starts epochs
423- # after the dataset is first discovered to be ready
424441 m = (
425442 status .final_weight - status .initial_weight
426443 ) / self .smooth_sudden_starts
427444 x = epoch - status .epoch_ready + 1
428- current_weight = min (
429- m * x + status .initial_weight ,
430- status .final_weight ,
431- )
445+ lo = min (status .initial_weight , status .final_weight )
446+ hi = max (status .initial_weight , status .final_weight )
447+ current_weight = np .clip (m * x + status .initial_weight , lo , hi )
432448 self ._available_datasets .append ((status .dataset , current_weight ))
433449 else :
434- # if dataset is not ready for training, set weight to 0
435450 self ._available_datasets .append ((status .dataset , 0.0 ))
436451
437452 def on_end_of_collection (self , experiment ) -> dict [str , Any ]:
@@ -455,9 +470,12 @@ def on_end_of_collection(self, experiment) -> dict[str, Any]:
455470
456471 # find datasets that are now ready for training, and adjust their sampling weight
457472 self .update_dset_weights (experiment .epoch )
458- # log sampling weight logic to wandb
473+ # log raw and normalized sampling weights to wandb
474+ total_weight = sum (w for _ , w in self ._available_datasets )
459475 for dset , active_weight in self ._available_datasets :
460476 metrics [f"{ dset .dset_name } Current Sample Weight" ] = active_weight
477+ normalized = active_weight / total_weight if total_weight > 0 else 0.0
478+ metrics [f"{ dset .dset_name } Normalized Sample Weight" ] = normalized
461479 return metrics
462480
463481 def delete (self ):
0 commit comments