2828from compiler_opt .es import blackbox_optimizers
2929from compiler_opt .distributed import buffered_scheduler
3030from compiler_opt .rl import compilation_runner
31+ from compiler_opt import baseline_cache
3132
3233
3334def _extract_results (futures : list [concurrent .futures .Future ]) -> list [Any ]:
@@ -51,6 +52,8 @@ def __init__(self, *, train_corpus: corpus.Corpus,
5152 estimator_type : blackbox_optimizers .EstimatorType ):
5253 self ._train_corpus = train_corpus
5354 self ._estimator_type = estimator_type
55+ self ._baseline_cache = baseline_cache .BaselineCache (
56+ get_key = lambda x : x .name )
5457
5558 @abc .abstractmethod
5659 def get_results (
@@ -73,12 +76,17 @@ def __init__(self,
7376 num_ir_repeats_within_worker : int = 1 ,
7477 ** kwargs ):
7578 super ().__init__ (** kwargs )
76- self ._samples : list [list [corpus .LoadedModuleSpec ]] = []
7779 self ._total_num_perturbations = total_num_perturbations
7880 self ._num_ir_repeats_within_worker = num_ir_repeats_within_worker
79- self ._baselines : list [ float | None ] | None = None
81+ self ._reset ()
8082
81- def _load_samples (self ) -> None :
83+ def _reset (self ):
84+ # TODO: this object is currently supposed to respect a state transition
85+ # and that makes it less maintainable than if not.
86+ self ._samples = None
87+ self ._baselines = None
88+
89+ def load_samples (self ) -> None :
8290 """Samples and loads modules if not already done.
8391
8492 Ensures self._samples contains the expected number of loaded samples.
@@ -89,6 +97,7 @@ def _load_samples(self) -> None:
8997 """
9098 if self ._samples :
9199 raise RuntimeError ('Samples have already been loaded.' )
100+ self ._samples = []
92101 for _ in range (self ._total_num_perturbations ):
93102 samples = self ._train_corpus .sample (self ._num_ir_repeats_within_worker )
94103 loaded_samples = [
@@ -108,15 +117,15 @@ def _load_samples(self) -> None:
108117 if len (self ._samples ) != expected_count :
109118 raise RuntimeError ('Some samples could not be loaded correctly.' )
110119
111- def _launch_compilation_workers (self ,
112- pool : FixedWorkerPool ,
113- perturbations : list [ bytes ] | None = None
114- ) -> list [concurrent . futures . Future ]:
115- if self . _samples is None :
116- raise RuntimeError ( 'Loaded samples are not available.' )
120+ def _launch_compilation_workers (
121+ self ,
122+ pool : FixedWorkerPool ,
123+ samples : list [list [ corpus . LoadedModuleSpec ]],
124+ perturbations : list [ bytes ] | None = None
125+ ) -> list [ concurrent . futures . Future ]:
117126 if perturbations is None :
118- perturbations = [None ] * len (self . _samples )
119- compile_args = zip (perturbations , self . _samples )
127+ perturbations = [None ] * len (samples )
128+ compile_args = zip (perturbations , samples )
120129 _ , futures = buffered_scheduler .schedule_on_worker_pool (
121130 action = lambda w , args : w .compile (policy = args [0 ], modules = args [1 ]),
122131 jobs = compile_args ,
@@ -130,24 +139,43 @@ def _launch_compilation_workers(self,
130139 not_done , return_when = concurrent .futures .FIRST_COMPLETED )
131140 return futures
132141
def ensure_baselines(self, pool):
    """Compute one baseline score per sample group and store them.

    Every module in self._samples is scored individually through
    self._baseline_cache.get_score, and the per-module scores are then summed
    back into one value per group, written to self._baselines. A group whose
    modules include a None score gets a None baseline.

    Args:
        pool: worker pool forwarded to _launch_compilation_workers for the
            modules the cache asks to have scored.

    Raises:
        RuntimeError: if self._samples is None (samples were never loaded).
    """
    if self._samples is None:
        raise RuntimeError('Loaded samples are not available.')
    # Flatten the groups so each module can be looked up / scored on its own.
    flat_samples = [item for sublist in self._samples for item in sublist]

    def _get_scores(some_list):
        # Wrap each module in its own single-element job so that it receives
        # its own score; no perturbation is passed (baseline policy).
        futures = self._launch_compilation_workers(
            pool, [[x] for x in some_list])
        return _extract_results(futures)

    # NOTE(review): presumably get_score returns one score per element of
    # flat_samples, in the same order -- the regrouping below relies on that.
    baselines = self._baseline_cache.get_score(flat_samples, _get_scores)

    # TODO: the business of accumulating compilation results is now shared
    # with the worker.
    def sum_or_none(lst):
        return sum(lst) if all(x is not None for x in lst) else None

    # Regroup the flat scores. The start of each group's slice is the running
    # total of the preceding group sizes, NOT the group index: using the index
    # makes slices overlap as soon as a group holds more than one module.
    grouped = []
    offset = 0
    for group in self._samples:
        end = offset + len(group)
        grouped.append(sum_or_none(baselines[offset:end]))
        offset = end
    self._baselines = grouped
def get_results(
        self, pool: FixedWorkerPool,
        perturbations: list[bytes]) -> list[concurrent.futures.Future]:
    """Launch compilations of the current samples under each perturbation.

    Samples are loaded and baselines computed on demand, so callers do not
    need to prepare either beforehand.

    Args:
        pool: worker pool the compilations are scheduled on.
        perturbations: one serialized policy per sample group.

    Returns:
        The futures returned by _launch_compilation_workers.
    """
    if not self._samples:
        self.load_samples()
    # Guarded separately from the samples check: load_samples() is public, so
    # samples may already be present while the baselines are still missing,
    # in which case get_rewards would otherwise raise.
    if self._baselines is None:
        self.ensure_baselines(pool)
    return self._launch_compilation_workers(pool, self._samples, perturbations)
140171
def set_baseline(self, pool: FixedWorkerPool) -> None:
    """Do nothing; kept so existing callers of set_baseline keep working.

    Baselines are computed by ensure_baselines(), which get_results() invokes,
    so there is no separate baseline step any more. `pool` is accepted and
    ignored.
    """
147174
148175 def get_rewards (
149176 self ,
150177 results_futures : list [concurrent .futures .Future ]) -> list [float | None ]:
178+ # we need a pool to get the baselines, so we should have gotten them already
151179 if self ._baselines is None :
152180 raise RuntimeError ('The baseline has not been set.' )
153181
@@ -165,6 +193,7 @@ def get_rewards(
165193 else :
166194 rewards .append (
167195 compilation_runner .calculate_reward (policy_result , baseline ))
196+ self ._reset ()
168197 return rewards
169198
170199
0 commit comments