Skip to content

Commit 0b3347f

Browse files
authored
Fix for eventloop in AutoPDL optimizer (#1390)
Signed-off-by: Mandana Vaziri <[email protected]>
1 parent f0285d5 commit 0b3347f

File tree

5 files changed

+49
-1
lines changed

5 files changed

+49
-1
lines changed

examples/optimizer/optimize.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@
1616
from pdl.optimize.optimizer_evaluator import OptimizerEvaluator
1717
from pdl.optimize.pdl_evaluator import PdlEvaluator
1818
from pdl.optimize.pdl_optimizer import PDLOptimizer
19+
from pdl.pdl import InterpreterConfig
20+
from pdl.pdl_scheduler import create_event_loop_thread
21+
22+
_LOOP = create_event_loop_thread()
23+
1924

2025
if __name__ == "__main__":
2126
parser = argparse.ArgumentParser(
@@ -138,13 +143,17 @@
138143
print(f"Unknown dataset: {config.dataset}")
139144
sys.exit(1)
140145

146+
PDLConfig = InterpreterConfig()
147+
PDLConfig["event_loop"] = _LOOP
148+
141149
# Create optimizer instance
142150
optimizer = PDLOptimizer(
143151
dataset=dataset,
144152
trial_thread=TrialThread,
145153
yield_output=args.yield_output,
146154
experiment_path=args.experiments_path,
147155
config=config,
156+
pdl_config=PDLConfig,
148157
)
149158

150159
# Execute the appropriate command

src/pdl/optimize/optimizer_evaluator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
timeout: int,
2828
yield_output: bool,
2929
config: OptimizationConfig,
30+
pdl_config: InterpreterConfig,
3031
cwd: Path,
3132
answer_key: str = "answer",
3233
) -> None:
@@ -40,6 +41,7 @@ def __init__(
4041
self.yield_output = yield_output
4142
self.answer_key = answer_key
4243
self.config = config
44+
self.pdl_config = pdl_config
4345
self.cwd = cwd
4446

4547
def get_scope(self) -> ScopeType:
@@ -75,6 +77,7 @@ def run( # type: ignore # noqa: C901
7577
yield_result=self.yield_output,
7678
yield_background=self.yield_output,
7779
cwd=self.cwd,
80+
event_loop=self.pdl_config["event_loop"], # type: ignore
7881
)
7982
scope = self.get_scope()
8083

src/pdl/optimize/pdl_evaluator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ def score(self, document: str, ground_truth: Any) -> float:
4949
args:
5050
document: ${{ document }}
5151
ground_truth: ${{ ground_truth }}"""
52-
result = exec_str(prog=prog, scope=scope, output="result")
52+
53+
result = exec_str(
54+
prog=prog, config=self.pdl_config, scope=scope, output="result"
55+
)
5356

5457
if isinstance(result, str):
5558
result = result.strip()

src/pdl/optimize/pdl_optimizer.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
from pdl.optimize.optimizer_evaluator import OptimizerEvaluator
2727
from pdl.optimize.pdl_evaluator import PdlEvaluator
2828
from pdl.optimize.util import CandidateResult, TrialOutput, console, execute_threads
29+
from pdl.pdl import InterpreterConfig
2930
from pdl.pdl_ast import AdvancedBlockType, DataBlock, Program
3031
from pdl.pdl_dumper import dump_program_exclude_internals
32+
from pdl.pdl_scheduler import create_event_loop_thread
3133

3234
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
3335
FORMAT = "%(message)s"
@@ -36,6 +38,8 @@
3638
logger.addHandler(RichHandler())
3739
rng = default_rng()
3840

41+
_LOOP = create_event_loop_thread()
42+
3943

4044
def resave_pdl(input_path: Path, output_path: Path, state: dict) -> int:
4145
with (input_path.open(encoding="utf-8") as pdl,):
@@ -73,6 +77,7 @@ def __init__(
7377
self,
7478
dataset: DatasetDict,
7579
config: OptimizationConfig,
80+
pdl_config: InterpreterConfig,
7681
trial_thread: type[OptimizerEvaluator],
7782
yield_output: bool,
7883
experiment_path: Path,
@@ -81,6 +86,7 @@ def __init__(
8186
self.yield_output = yield_output
8287

8388
self.config = config
89+
self.pdl_config = pdl_config
8490
self.pdl_path = Path(config.pdl_path)
8591
self.parallelism = config.parallelism
8692
self.num_demonstrations = config.num_demonstrations
@@ -631,6 +637,7 @@ def evaluate(
631637
timeout=self.timeout,
632638
yield_output=self.yield_output,
633639
config=self.config,
640+
pdl_config=self.pdl_config,
634641
cwd=pdl_file_parent,
635642
),
636643
)
@@ -827,13 +834,17 @@ def run_optimizer() -> int:
827834
print(f"Unknown dataset: {config.dataset}")
828835
sys.exit(1)
829836

837+
pdl_config = InterpreterConfig()
838+
pdl_config["event_loop"] = _LOOP
839+
830840
# Create optimizer instance
831841
optimizer = PDLOptimizer(
832842
dataset=dataset,
833843
trial_thread=PdlEvaluator,
834844
yield_output=args.yield_output,
835845
experiment_path=args.experiments_path,
836846
config=config,
847+
pdl_config=pdl_config,
837848
)
838849
optimizer.run()
839850
return 0

tests/test_optimizer.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111
from examples.optimizer.mbpp_evaluator import MBPPEvaluator
1212
from pdl.optimize.config_parser import OptimizationConfig
1313
from pdl.optimize.pdl_optimizer import PDLOptimizer
14+
from pdl.pdl import InterpreterConfig
15+
from pdl.pdl_scheduler import create_event_loop_thread
16+
17+
_LOOP = create_event_loop_thread()
1418

1519

1620
def test_gsm8k_cot():
@@ -351,12 +355,17 @@ def test_gsm8k_cot():
351355
"test": Dataset.from_list(gsm8k["test"]),
352356
},
353357
)
358+
359+
pdl_config = InterpreterConfig()
360+
pdl_config["event_loop"] = _LOOP
361+
354362
optim = PDLOptimizer(
355363
dataset=gsm8k,
356364
trial_thread=Gsm8kEvaluator,
357365
yield_output=True,
358366
experiment_path=Path("test_experiments"),
359367
config=config,
368+
pdl_config=pdl_config,
360369
)
361370

362371
result = optim.run()
@@ -712,12 +721,17 @@ def run_optimizer_gsm8k(pattern, num_demonstrations=0):
712721
"test": Dataset.from_list(gsm8k["test"]),
713722
},
714723
)
724+
725+
pdl_config = InterpreterConfig()
726+
pdl_config["event_loop"] = _LOOP
727+
715728
optim = PDLOptimizer(
716729
dataset=gsm8k,
717730
trial_thread=Gsm8kEvaluator,
718731
yield_output=True,
719732
experiment_path=Path("test_experiments"),
720733
config=config,
734+
pdl_config=pdl_config,
721735
)
722736

723737
result = optim.run()
@@ -1078,12 +1092,16 @@ def run_optimizer_fever(pattern, num_demonstrations=0):
10781092
},
10791093
)
10801094

1095+
pdl_config = InterpreterConfig()
1096+
pdl_config["event_loop"] = _LOOP
1097+
10811098
optim = PDLOptimizer(
10821099
dataset=fever, # pyright: ignore
10831100
trial_thread=FEVEREvaluator,
10841101
yield_output=True,
10851102
experiment_path=Path("test_experiments"),
10861103
config=config,
1104+
pdl_config=pdl_config,
10871105
)
10881106

10891107
result = optim.run()
@@ -1123,12 +1141,16 @@ def run_optimizer_mbpp(pattern, num_demonstrations=0):
11231141
"../prompt-declaration-language-merge/var/mbpp_trajectified",
11241142
)
11251143

1144+
pdl_config = InterpreterConfig()
1145+
pdl_config["event_loop"] = _LOOP
1146+
11261147
optim = PDLOptimizer(
11271148
dataset=mbpp_dataset, # pyright: ignore
11281149
trial_thread=MBPPEvaluator,
11291150
yield_output=True,
11301151
experiment_path=Path("test_experiments"),
11311152
config=config,
1153+
pdl_config=pdl_config,
11321154
)
11331155

11341156
result = optim.run()

0 commit comments

Comments
 (0)