Skip to content

Commit d027b39

Browse files
authored
Merge pull request #100 from UT-Austin-RPL/scalar_loss
improve MixtureOfDatasets anneal
2 parents d5ec396 + 313b770 commit d027b39

1 file changed

Lines changed: 33 additions & 15 deletions

File tree

amago/loading.py

Lines changed: 33 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -327,8 +327,17 @@ class MixtureOfDatasets(RLDataset):
327327
Args:
328328
datasets: A list of :py:class:`~amago.loading.RLDataset` objects.
329329
sampling_weights: Probability of sampling from each dataset. Must sum to 1.
330-
smooth_sudden_starts: When a dataset becomes ready for training mid-way through training,
331-
anneal its sampling_weight from 0 --> assigned ``sampling_weights[i]`` over ``smooth_sudden_starts`` epochs.
330+
These are the *final* target weights after annealing completes.
331+
initial_sampling_weights: Optional starting weights before annealing. When
332+
provided, each dataset anneals from ``initial_sampling_weights[i]`` to
333+
``sampling_weights[i]`` over ``smooth_sudden_starts`` epochs (starting
334+
from the epoch the dataset becomes ready for training). When ``None``
335+
(default), datasets that are ready at the start of training jump
336+
straight to their target weight, while datasets that become ready later
337+
anneal up from 0.
338+
smooth_sudden_starts: Number of epochs over which to anneal each dataset's
339+
sampling weight from its initial value to its target value, starting
340+
from the epoch the dataset becomes ready for training.
332341
dset_name: The name of the dataset. Used to identify the dataset in logging metrics. Defaults to the class name.
333342
334343
Note:
@@ -345,6 +354,7 @@ def __init__(
345354
self,
346355
datasets: list[RLDataset],
347356
sampling_weights: list[float],
357+
initial_sampling_weights: Optional[list[float]] = None,
348358
smooth_sudden_starts: Optional[int] = None,
349359
dset_name: Optional[str] = None,
350360
):
@@ -357,8 +367,15 @@ def __init__(
357367
raise ValueError(
358368
f"{self.__class__.__name__} requires the sum of sampling weights to be 1.0"
359369
)
370+
if initial_sampling_weights is not None and len(
371+
initial_sampling_weights
372+
) != len(datasets):
373+
raise ValueError(
374+
f"{self.__class__.__name__} requires the same number of initial_sampling_weights ({len(initial_sampling_weights)}) as datasets ({len(datasets)})"
375+
)
360376
self.all_datasets = datasets
361377
self.sampling_weights = sampling_weights
378+
self.initial_sampling_weights = initial_sampling_weights
362379
self.smooth_sudden_starts = smooth_sudden_starts
363380

364381
# find up to 1 dataset that determines where Experiment will write new trajectories
@@ -378,10 +395,13 @@ def configure_from_experiment(self, experiment):
378395

379396
# with datasets configured, initialize our annealing schedule
380397
self._dsets_status = []
381-
for d, w in zip(self.all_datasets, self.sampling_weights):
382-
# if the dataset is available from the start of training, turn off its schedule
383-
# setting initial_weight to the final weight.
384-
initial_weight = w if d.ready_for_training else 0
398+
for i, (d, w) in enumerate(zip(self.all_datasets, self.sampling_weights)):
399+
if self.initial_sampling_weights is not None:
400+
initial_weight = self.initial_sampling_weights[i]
401+
else:
402+
# legacy behavior: ready datasets start at their final weight,
403+
# not-yet-ready datasets anneal up from 0
404+
initial_weight = w if d.ready_for_training else 0
385405
self._dsets_status.append(
386406
_DatasetStatus(
387407
dataset=d,
@@ -416,22 +436,17 @@ def update_dset_weights(self, epoch: int):
416436
for status in self._dsets_status:
417437
if status.epoch_ready is not None:
418438
if self.smooth_sudden_starts is None:
419-
# active datasets jump right to their final weight
420439
current_weight = status.final_weight
421440
else:
422-
# linear schedule for smooth_sudden_starts epochs
423-
# after the dataset is first discovered to be ready
424441
m = (
425442
status.final_weight - status.initial_weight
426443
) / self.smooth_sudden_starts
427444
x = epoch - status.epoch_ready + 1
428-
current_weight = min(
429-
m * x + status.initial_weight,
430-
status.final_weight,
431-
)
445+
lo = min(status.initial_weight, status.final_weight)
446+
hi = max(status.initial_weight, status.final_weight)
447+
current_weight = np.clip(m * x + status.initial_weight, lo, hi)
432448
self._available_datasets.append((status.dataset, current_weight))
433449
else:
434-
# if dataset is not ready for training, set weight to 0
435450
self._available_datasets.append((status.dataset, 0.0))
436451

437452
def on_end_of_collection(self, experiment) -> dict[str, Any]:
@@ -455,9 +470,12 @@ def on_end_of_collection(self, experiment) -> dict[str, Any]:
455470

456471
# find datasets that are now ready for training, and adjust their sampling weight
457472
self.update_dset_weights(experiment.epoch)
458-
# log sampling weight logic to wandb
473+
# log raw and normalized sampling weights to wandb
474+
total_weight = sum(w for _, w in self._available_datasets)
459475
for dset, active_weight in self._available_datasets:
460476
metrics[f"{dset.dset_name} Current Sample Weight"] = active_weight
477+
normalized = active_weight / total_weight if total_weight > 0 else 0.0
478+
metrics[f"{dset.dset_name} Normalized Sample Weight"] = normalized
461479
return metrics
462480

463481
def delete(self):

0 commit comments

Comments
 (0)