|
1 | | -import shutil |
2 | | -from copy import deepcopy |
3 | 1 | from dataclasses import dataclass |
4 | | -from pathlib import Path |
5 | 2 | from typing import Optional, Union |
6 | 3 |
|
7 | 4 | from simple_parsing import ArgumentParser, ConflictResolution |
|
17 | 14 | TrackstarConfig, |
18 | 15 | ) |
19 | 16 | from .hessians.hessian_approximations import approximate_hessians |
20 | | -from .process_grads import mix_preconditioners |
21 | 17 | from .query.query_index import query |
22 | 18 | from .reduce import reduce |
23 | 19 | from .score.score import score_dataset |
24 | | - |
25 | | - |
26 | | -def validate_run_path(index_cfg: IndexConfig): |
27 | | - """Validate the run path.""" |
28 | | - if index_cfg.distributed.rank != 0: |
29 | | - return |
30 | | - |
31 | | - for path in [Path(index_cfg.run_path), Path(index_cfg.partial_run_path)]: |
32 | | - if not path.exists(): |
33 | | - continue |
34 | | - |
35 | | - if index_cfg.overwrite: |
36 | | - shutil.rmtree(path) |
37 | | - else: |
38 | | - raise FileExistsError( |
39 | | - f"Run path {path} already exists. Use --overwrite to overwrite it." |
40 | | - ) |
| 20 | +from .trackstar import trackstar |
| 21 | +from .utils.worker_utils import validate_run_path |
41 | 22 |
|
42 | 23 |
|
43 | 24 | @dataclass |
@@ -150,70 +131,17 @@ class Trackstar: |
150 | 131 |
|
151 | 132 | index_cfg: IndexConfig |
152 | 133 |
|
153 | | - trackstar_cfg: TrackstarConfig |
154 | | - |
155 | 134 | score_cfg: ScoreConfig |
156 | 135 |
|
157 | 136 | preprocess_cfg: PreprocessConfig |
158 | 137 |
|
| 138 | + trackstar_cfg: TrackstarConfig |
| 139 | + |
159 | 140 | def execute(self): |
160 | | - """Run the full trackstar pipeline: preconditioners -> mix -> build -> score.""" |
161 | | - run_path = self.index_cfg.run_path |
162 | | - value_precond_path = f"{run_path}/value_preconditioner" |
163 | | - query_precond_path = f"{run_path}/query_preconditioner" |
164 | | - mixed_precond_path = f"{run_path}/mixed_preconditioner" |
165 | | - query_path = f"{run_path}/query" |
166 | | - scores_path = f"{run_path}/scores" |
167 | | - |
168 | | - # Step 1: Compute normalizers and preconditioners on value dataset |
169 | | - print("Step 1/5: Computing normalizers and preconditioners on value dataset...") |
170 | | - value_precond_cfg = deepcopy(self.index_cfg) |
171 | | - value_precond_cfg.run_path = value_precond_path |
172 | | - value_precond_cfg.skip_index = True |
173 | | - value_precond_cfg.skip_preconditioners = False |
174 | | - validate_run_path(value_precond_cfg) |
175 | | - build(value_precond_cfg, self.preprocess_cfg) |
176 | | - |
177 | | - # Step 2: Compute normalizers and preconditioners on query dataset |
178 | | - print("Step 2/5: Computing normalizers and preconditioners on query dataset...") |
179 | | - query_precond_cfg = deepcopy(self.index_cfg) |
180 | | - query_precond_cfg.run_path = query_precond_path |
181 | | - query_precond_cfg.data = self.trackstar_cfg.query |
182 | | - query_precond_cfg.skip_index = True |
183 | | - query_precond_cfg.skip_preconditioners = False |
184 | | - validate_run_path(query_precond_cfg) |
185 | | - build(query_precond_cfg, self.preprocess_cfg) |
186 | | - |
187 | | - # Step 3: Mix query and value preconditioners |
188 | | - print("Step 3/5: Mixing preconditioners...") |
189 | | - mix_preconditioners( |
190 | | - query_path=query_precond_path, |
191 | | - index_path=value_precond_path, |
192 | | - output_path=mixed_precond_path, |
193 | | - mixing_coefficient=self.trackstar_cfg.mixing_coefficient, |
| 141 | + trackstar( |
| 142 | + self.index_cfg, self.score_cfg, self.preprocess_cfg, self.trackstar_cfg |
194 | 143 | ) |
195 | 144 |
|
196 | | - # Step 4: Build per-item query gradient index |
197 | | - print("Step 4/5: Building query gradient index...") |
198 | | - query_cfg = deepcopy(self.index_cfg) |
199 | | - query_cfg.run_path = query_path |
200 | | - query_cfg.data = self.trackstar_cfg.query |
201 | | - query_cfg.processor_path = query_precond_path |
202 | | - query_cfg.skip_preconditioners = True |
203 | | - validate_run_path(query_cfg) |
204 | | - build(query_cfg, self.preprocess_cfg) |
205 | | - |
206 | | - # Step 5: Score value dataset against query using mixed preconditioner |
207 | | - print("Step 5/5: Scoring value dataset...") |
208 | | - score_index_cfg = deepcopy(self.index_cfg) |
209 | | - score_index_cfg.run_path = scores_path |
210 | | - score_index_cfg.processor_path = value_precond_path |
211 | | - score_index_cfg.skip_preconditioners = True |
212 | | - self.score_cfg.query_path = query_path |
213 | | - self.preprocess_cfg.preconditioner_path = mixed_precond_path |
214 | | - validate_run_path(score_index_cfg) |
215 | | - score_dataset(score_index_cfg, self.score_cfg, self.preprocess_cfg) |
216 | | - |
217 | 145 |
|
218 | 146 | @dataclass |
219 | 147 | class Main: |
|
0 commit comments