Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions examples/optimizer/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
from pdl.optimize.optimizer_evaluator import OptimizerEvaluator
from pdl.optimize.pdl_evaluator import PdlEvaluator
from pdl.optimize.pdl_optimizer import PDLOptimizer
from pdl.pdl import InterpreterConfig
from pdl.pdl_scheduler import create_event_loop_thread

_LOOP = create_event_loop_thread()


if __name__ == "__main__":
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -138,13 +143,17 @@
print(f"Unknown dataset: {config.dataset}")
sys.exit(1)

PDLConfig = InterpreterConfig()
PDLConfig["event_loop"] = _LOOP

# Create optimizer instance
optimizer = PDLOptimizer(
dataset=dataset,
trial_thread=TrialThread,
yield_output=args.yield_output,
experiment_path=args.experiments_path,
config=config,
pdl_config=PDLConfig,
)

# Execute the appropriate command
Expand Down
3 changes: 3 additions & 0 deletions src/pdl/optimize/optimizer_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(
timeout: int,
yield_output: bool,
config: OptimizationConfig,
pdl_config: InterpreterConfig,
cwd: Path,
answer_key: str = "answer",
) -> None:
Expand All @@ -40,6 +41,7 @@ def __init__(
self.yield_output = yield_output
self.answer_key = answer_key
self.config = config
self.pdl_config = pdl_config
self.cwd = cwd

def get_scope(self) -> ScopeType:
Expand Down Expand Up @@ -75,6 +77,7 @@ def run( # type: ignore # noqa: C901
yield_result=self.yield_output,
yield_background=self.yield_output,
cwd=self.cwd,
event_loop=self.pdl_config["event_loop"], # type: ignore
)
scope = self.get_scope()

Expand Down
5 changes: 4 additions & 1 deletion src/pdl/optimize/pdl_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,10 @@ def score(self, document: str, ground_truth: Any) -> float:
args:
document: ${{ document }}
ground_truth: ${{ ground_truth }}"""
result = exec_str(prog=prog, scope=scope, output="result")

result = exec_str(
prog=prog, config=self.pdl_config, scope=scope, output="result"
)

if isinstance(result, str):
result = result.strip()
Expand Down
11 changes: 11 additions & 0 deletions src/pdl/optimize/pdl_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@
from pdl.optimize.optimizer_evaluator import OptimizerEvaluator
from pdl.optimize.pdl_evaluator import PdlEvaluator
from pdl.optimize.util import CandidateResult, TrialOutput, console, execute_threads
from pdl.pdl import InterpreterConfig
from pdl.pdl_ast import AdvancedBlockType, DataBlock, Program
from pdl.pdl_dumper import dump_program_exclude_internals
from pdl.pdl_scheduler import create_event_loop_thread

warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
FORMAT = "%(message)s"
Expand All @@ -36,6 +38,8 @@
logger.addHandler(RichHandler())
rng = default_rng()

_LOOP = create_event_loop_thread()


def resave_pdl(input_path: Path, output_path: Path, state: dict) -> int:
with (input_path.open(encoding="utf-8") as pdl,):
Expand Down Expand Up @@ -73,6 +77,7 @@ def __init__(
self,
dataset: DatasetDict,
config: OptimizationConfig,
pdl_config: InterpreterConfig,
trial_thread: type[OptimizerEvaluator],
yield_output: bool,
experiment_path: Path,
Expand All @@ -81,6 +86,7 @@ def __init__(
self.yield_output = yield_output

self.config = config
self.pdl_config = pdl_config
self.pdl_path = Path(config.pdl_path)
self.parallelism = config.parallelism
self.num_demonstrations = config.num_demonstrations
Expand Down Expand Up @@ -631,6 +637,7 @@ def evaluate(
timeout=self.timeout,
yield_output=self.yield_output,
config=self.config,
pdl_config=self.pdl_config,
cwd=pdl_file_parent,
),
)
Expand Down Expand Up @@ -827,13 +834,17 @@ def run_optimizer() -> int:
print(f"Unknown dataset: {config.dataset}")
sys.exit(1)

pdl_config = InterpreterConfig()
pdl_config["event_loop"] = _LOOP

# Create optimizer instance
optimizer = PDLOptimizer(
dataset=dataset,
trial_thread=PdlEvaluator,
yield_output=args.yield_output,
experiment_path=args.experiments_path,
config=config,
pdl_config=pdl_config,
)
optimizer.run()
return 0
22 changes: 22 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@
from examples.optimizer.mbpp_evaluator import MBPPEvaluator
from pdl.optimize.config_parser import OptimizationConfig
from pdl.optimize.pdl_optimizer import PDLOptimizer
from pdl.pdl import InterpreterConfig
from pdl.pdl_scheduler import create_event_loop_thread

_LOOP = create_event_loop_thread()


def test_gsm8k_cot():
Expand Down Expand Up @@ -351,12 +355,17 @@ def test_gsm8k_cot():
"test": Dataset.from_list(gsm8k["test"]),
},
)

pdl_config = InterpreterConfig()
pdl_config["event_loop"] = _LOOP

optim = PDLOptimizer(
dataset=gsm8k,
trial_thread=Gsm8kEvaluator,
yield_output=True,
experiment_path=Path("test_experiments"),
config=config,
pdl_config=pdl_config,
)

result = optim.run()
Expand Down Expand Up @@ -712,12 +721,17 @@ def run_optimizer_gsm8k(pattern, num_demonstrations=0):
"test": Dataset.from_list(gsm8k["test"]),
},
)

pdl_config = InterpreterConfig()
pdl_config["event_loop"] = _LOOP

optim = PDLOptimizer(
dataset=gsm8k,
trial_thread=Gsm8kEvaluator,
yield_output=True,
experiment_path=Path("test_experiments"),
config=config,
pdl_config=pdl_config,
)

result = optim.run()
Expand Down Expand Up @@ -1078,12 +1092,16 @@ def run_optimizer_fever(pattern, num_demonstrations=0):
},
)

pdl_config = InterpreterConfig()
pdl_config["event_loop"] = _LOOP

optim = PDLOptimizer(
dataset=fever, # pyright: ignore
trial_thread=FEVEREvaluator,
yield_output=True,
experiment_path=Path("test_experiments"),
config=config,
pdl_config=pdl_config,
)

result = optim.run()
Expand Down Expand Up @@ -1123,12 +1141,16 @@ def run_optimizer_mbpp(pattern, num_demonstrations=0):
"../prompt-declaration-language-merge/var/mbpp_trajectified",
)

pdl_config = InterpreterConfig()
pdl_config["event_loop"] = _LOOP

optim = PDLOptimizer(
dataset=mbpp_dataset, # pyright: ignore
trial_thread=MBPPEvaluator,
yield_output=True,
experiment_path=Path("test_experiments"),
config=config,
pdl_config=pdl_config,
)

result = optim.run()
Expand Down