def benchmarks_from_distrib(
    self,
    datasets: List[str] = None,
    weights: List[float] = None,
    dataset_size: int = -1,
) -> Iterable[Benchmark]:
    """Sample benchmarks from a weighted mixture of datasets.

    For every benchmark produced, a dataset is drawn at random according to
    ``weights`` and the next benchmark of that dataset is yielded. A dataset
    that runs out of benchmarks is removed from the mixture and the
    remaining weights are renormalized, so no benchmark is yielded twice
    from the same dataset iterator.

    Arguments are validated eagerly, when this method is called; the
    sampling itself is lazy and only happens as the result is iterated.

    :param datasets: Names of the datasets to sample from, as used for the
        keys of ``self._datasets``. If omitted or empty, every registered
        dataset is used.
    :param weights: One non-negative sampling weight per entry of
        ``datasets``. They do not need to sum to one. If ``None``, datasets
        are selected uniformly.
    :param dataset_size: Maximum number of benchmarks to yield. ``-1``
        means no limit: iteration continues until every dataset with a
        non-zero weight is exhausted.
    :return: A lazy iterable of benchmarks.
    :raises ValueError: If there are no datasets to sample from, if the
        number of weights does not match the number of datasets, or if the
        weights are negative or do not have a positive sum.
    :raises LookupError: If a requested dataset name is not registered.
    """
    # Sample over dataset *names*: they are what the registry is keyed by.
    names = list(datasets or self._datasets)
    if not names:
        raise ValueError("No datasets to sample from")
    if weights is None:
        weights = [1 / len(names)] * len(names)
    if len(weights) != len(names):
        raise ValueError(
            "Mismatch between datasets size: {} and sampling weights length: {}!".format(
                len(names), len(weights)
            )
        )
    # Report unknown names up front rather than only if they get sampled.
    for name in names:
        if name not in self._datasets:
            raise LookupError(f"Dataset not found: {name}")
    # `not total > 0` (instead of `total <= 0`) also rejects a NaN sum.
    if any(weight < 0 for weight in weights) or not sum(weights) > 0:
        raise ValueError(
            "Sampling weights must be non-negative and sum to a positive value"
        )

    def _draw():
        streams = [iter(self._datasets[name]) for name in names]
        remaining = [float(weight) for weight in weights]
        produced = 0
        while streams and (dataset_size == -1 or produced < dataset_size):
            total = sum(remaining)
            if total <= 0:
                # Only zero-weight datasets are left: nothing can be drawn.
                return
            # np.random.choice requires probabilities that sum to one.
            probabilities = [weight / total for weight in remaining]
            pick = int(np.random.choice(len(streams), p=probabilities))
            try:
                benchmark = next(streams[pick])
            except StopIteration:
                # Exhausted dataset: drop it and redraw from the rest.
                del streams[pick]
                del remaining[pick]
                continue
            produced += 1
            yield benchmark

    return _draw()