|
| 1 | +import json |
1 | 2 | import random
|
2 | 3 | from dataclasses import dataclass
|
| 4 | +from pathlib import Path |
3 | 5 | from typing import Any, AsyncIterable, Dict, Iterable, List, Set
|
4 | 6 |
|
5 | 7 | from kilroy_module_server_py_sdk import (
|
6 | 8 | CategorizableBasedParameter,
|
7 | 9 | Configurable,
|
8 | 10 | Parameter,
|
| 11 | + Savable, |
9 | 12 | SerializableModel,
|
10 | 13 | classproperty,
|
11 | 14 | )
|
@@ -70,21 +73,52 @@ def parameters(cls) -> Set[Parameter]:
|
70 | 73 |
|
71 | 74 | async def build_default_state(self) -> State:
|
72 | 75 | params = Params(**self._kwargs)
|
73 |
| - sampler_cls = Sampler.for_category(params.sampler_type) |
74 |
| - sampler_params = params.samplers_params.get(params.sampler_type, {}) |
75 |
| - if issubclass(sampler_cls, Configurable): |
76 |
| - sampler = await sampler_cls.build(**sampler_params) |
77 |
| - await sampler.init() |
78 |
| - else: |
79 |
| - sampler = sampler_cls(**sampler_params) |
80 | 76 | return State(
|
81 |
| - sampler=sampler, |
| 77 | + sampler=await self.build_generic( |
| 78 | + Sampler, |
| 79 | + category=params.sampler_type, |
| 80 | + **params.samplers_params.get(params.sampler_type, {}), |
| 81 | + ), |
82 | 82 | samplers_params=params.samplers_params,
|
83 | 83 | contexts=params.contexts,
|
84 | 84 | max_length=params.max_length,
|
85 | 85 | batch_size=params.batch_size,
|
86 | 86 | )
|
87 | 87 |
|
| 88 | + async def save_state(self, state: State, directory: Path) -> None: |
| 89 | + state_dict = { |
| 90 | + "sampler_type": state.sampler.category, |
| 91 | + "samplers_params": state.samplers_params, |
| 92 | + "contexts": state.contexts, |
| 93 | + "max_length": state.max_length, |
| 94 | + "batch_size": state.batch_size, |
| 95 | + } |
| 96 | + if isinstance(state.sampler, Savable): |
| 97 | + await state.sampler.save(directory / "sampler") |
| 98 | + with open(directory / "state.json", "w") as f: |
| 99 | + json.dump(state_dict, f) |
| 100 | + |
| 101 | + async def load_saved_state(self, directory: Path) -> State: |
| 102 | + with open(directory / "state.json", "r") as f: |
| 103 | + state_dict = json.load(f) |
| 104 | + sampler_type = state_dict["sampler_type"] |
| 105 | + sampler_kwargs = { |
| 106 | + **self._kwargs.get("samplers_params", {}).get(sampler_type, {}), |
| 107 | + **state_dict["samplers_params"].get(sampler_type, {}), |
| 108 | + } |
| 109 | + return State( |
| 110 | + sampler=await self.load_generic( |
| 111 | + directory / "sampler", |
| 112 | + Sampler, |
| 113 | + category=state_dict["sampler_type"], |
| 114 | + **sampler_kwargs, |
| 115 | + ), |
| 116 | + samplers_params=state_dict["samplers_params"], |
| 117 | + contexts=state_dict["contexts"], |
| 118 | + max_length=state_dict["max_length"], |
| 119 | + batch_size=state_dict["batch_size"], |
| 120 | + ) |
| 121 | + |
88 | 122 | async def cleanup(self) -> None:
|
89 | 123 | async with self.state.write_lock() as state:
|
90 | 124 | if isinstance(state.sampler, Configurable):
|
|
0 commit comments