-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathconfig.py
More file actions
281 lines (237 loc) · 10.9 KB
/
config.py
File metadata and controls
281 lines (237 loc) · 10.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
from pydantic import (
BaseModel,
PositiveInt,
NonNegativeInt,
NonNegativeFloat,
ConfigDict,
Field,
field_serializer,
)
from typing import List, Optional, Literal, Any, Set, Dict
from collections.abc import Callable, Mapping
from functools import cache
from pathlib import Path
from interactor import InteractorBasis
from basis import ModelLike, Stage
from mcap_data_loader.utils.basic import force_set_attr, ConstrainedIterable
from mcap_data_loader.basis.cfgable import dump_or_repr
from mcap_data_loader.basis.data_loader import DataLoaderKey, MappingProtocol
# Allowed units for interval-based settings (used by IntermittentConfig.unit).
IterUnit = Literal["epoch", "step", "sample", "minute"]

# Shared strict model configuration: unknown fields are rejected and values are
# re-validated whenever an attribute is assigned.
MODEL_CONFIG = ConfigDict(extra="forbid", validate_assignment=True)
class CommonConfig(BaseModel, frozen=True):
    """Common configuration parameters.

    These configurations are applicable to both training and testing.
    Relative paths are resolved once, in ``model_post_init``.
    """

    stage: Stage
    """Stage of this run."""
    root_dir: Path
    """Root directory for saving logs and checkpoints."""
    checkpoints_dir: Path
    """Directory for saving model checkpoints.

    A relative path is joined onto ``root_dir`` unless it already exists
    relative to the current working directory."""
    checkpoint_path: Optional[Path] = None
    """Path to a specific model checkpoint.

    If None, it is assigned automatically inside ``checkpoints_dir``: for the
    training stage it is ``<checkpoints_dir>/<next_id>``, where ``<next_id>`` is
    one more than the highest integer-named entry in ``checkpoints_dir`` (0 when
    there is none or the directory does not exist yet); for other stages it is
    the entry with the highest integer name, and a ValueError is raised if there
    is none. An explicit relative path is joined onto ``checkpoints_dir`` unless
    it already exists relative to the current working directory.
    """
    seed: Optional[int] = 42
    """Random seed for reproducibility."""
    loss_fn: Callable[[Any, Any], Any]
    """Loss function to use."""
    tb_log_dir: Path = Path("tensorboard")
    """Directory for TensorBoard logs; a relative path is resolved like ``checkpoints_dir``."""
    ewc_model: Optional[Path] = None
    """Optional model path used for EWC (presumably the reference model for the EWC penalty; TODO confirm)."""
    ewc_lambda: float = 100.0
    """EWC weight (presumably the coefficient of the EWC penalty term; TODO confirm against the trainer)."""

    @force_set_attr
    def model_post_init(self, context):
        """Resolve relative paths and choose a default checkpoint path.

        Raises:
            ValueError: if ``checkpoint_path`` is None, the stage is not
                training, and ``checkpoints_dir`` contains no integer-named entry.
        """

        def process_path(
            path: Optional[Path], relative_to: Path = self.root_dir
        ) -> Optional[Path]:
            if path is None:
                return None
            # Absolute paths and paths that already exist (relative to the CWD)
            # are kept as-is; anything else is anchored at ``relative_to``.
            return path if (path.is_absolute() or path.exists()) else relative_to / path

        self.checkpoints_dir = process_path(self.checkpoints_dir)
        if self.checkpoint_path is None:
            ckpt_dir = self.checkpoints_dir
            # A checkpoints directory that does not exist yet holds no
            # checkpoints; ``iterdir`` would raise FileNotFoundError on it.
            # ``isdecimal`` (unlike ``isdigit``) only accepts what ``int`` parses.
            ids = (
                [int(p.stem) for p in ckpt_dir.iterdir() if p.stem.isdecimal()]
                if ckpt_dir.is_dir()
                else []
            )
            max_id = max(ids, default=-1)
            if self.stage is Stage.TRAIN:
                next_id = max_id + 1
            else:
                if max_id < 0:
                    raise ValueError(
                        f"No checkpoints found in {ckpt_dir} for stage {self.stage}"
                    )
                next_id = max_id
            # Join directly onto the checkpoints directory. Passing the bare id
            # (a ``str``) through ``process_path`` would call ``Path`` methods on
            # a ``str`` and could pick up an unrelated ``./<id>`` in the CWD.
            self.checkpoint_path = ckpt_dir / str(next_id)
        else:
            self.checkpoint_path = process_path(
                self.checkpoint_path, self.checkpoints_dir
            )
        self.tb_log_dir = process_path(self.tb_log_dir)
class TestConfig(BaseModel, frozen=True):
    """Settings that only apply to the testing stage."""

    save_results: bool = True
    """If True, the test results are saved."""
    show_plot: bool = False
    """If True, plots of the results are shown."""
    rollout_steps: PositiveInt = 1
    """How many steps a rollout prediction covers; must be at least 1."""
class KeyConditionConfig(BaseModel, frozen=True):
    """Which criterion keys take part in the stopping decision, and how."""

    necessary: set[str] = set()
    """Keys that must all be satisfied before stopping (necessary conditions)."""
    sufficient: set[str] = set()
    """Keys any one of which is enough to stop (sufficient conditions)."""

    @property
    def keys(self) -> set[str]:
        """All keys, necessary and sufficient combined."""
        return self.necessary.union(self.sufficient)
class IntermittentConfig(BaseModel, frozen=True):
    """Settings for an operation that repeats at a fixed interval."""

    unit: IterUnit = "minute"
    """What the interval is measured in."""
    interval: PositiveInt = 1
    """Length of the interval, counted in ``unit``."""
    maximum: NonNegativeInt = 0
    """Upper bound on how many times the operation runs; 0 means unlimited."""
class SnapshotConfig(IntermittentConfig):
    """Interval settings for taking snapshots, plus the metrics to record."""

    keys: set[str] = set()
    """Metric keys captured in each snapshot."""
class TrainIterationConfig(BaseModel, frozen=True):
    """Configuration for training iteration.

    This class defines various stopping and continuation criteria for a training loop.
    Any criterion set to its default "no limit" value (typically 0 or 0.0) is effectively disabled.
    Training will stop when any of the sufficient conditions are met or all necessary conditions are satisfied.
    """

    model_config = MODEL_CONFIG

    patience: NonNegativeInt = 0
    """Number of consecutive epochs with no improvement in the train or val loss before early stopping is triggered.
    Set to 0 to disable early stopping."""
    max_epoch: NonNegativeInt = 0
    """Maximum number of epochs allowed for training."""
    min_epoch: NonNegativeInt = 0
    """Minimum number of epochs that must be completed before any stopping
    condition (e.g., patience or loss thresholds) is evaluated."""
    max_step: NonNegativeInt = 0
    """Maximum number of training steps (batches processed) allowed."""
    min_step: NonNegativeInt = 0
    """Minimum number of training steps that must be completed before stopping
    conditions are evaluated."""
    max_sample: NonNegativeInt = 0
    """Maximum number of training samples processed (across all epochs)."""
    min_sample: NonNegativeInt = 0
    """Minimum number of training samples that must be processed before stopping
    conditions are evaluated."""
    max_time: NonNegativeFloat = 0.0
    """Maximum wall-clock training and validation time in minutes. Training stops once exceeded.
    Set to 0.0 for no time limit."""
    min_time: NonNegativeFloat = 0.0
    """Minimum training and validation time in minutes that must elapse before any stopping
    condition is considered."""
    max_train_time: NonNegativeFloat = 0.0
    """Maximum training time in minutes (excluding validation). Training stops once exceeded."""
    min_train_time: NonNegativeFloat = 0.0
    """Minimum training time in minutes that must elapse before any stopping condition is considered."""
    max_train_loss: NonNegativeFloat = 0.0
    """Upper bound on training loss; training stops if loss exceeds this value."""
    min_train_loss: NonNegativeFloat = 0.0
    """Lower bound on training loss; training will not stop if loss drops below this value."""
    max_val_loss: NonNegativeFloat = 0.0
    """Upper bound on validation loss; training stops if validation loss exceeds this value."""
    min_val_loss: NonNegativeFloat = 0.0
    """Lower bound on validation loss; training will not stop if validation loss drops below this value."""
    conditions: KeyConditionConfig = KeyConditionConfig()
    """Configuration for stopping conditions."""

    def model_post_init(self, context):
        """Validate the explicit stopping conditions, or infer them.

        If ``conditions`` names any keys, every key must be a criterion field of
        this class, and no key containing ``"min"`` may be a sufficient
        condition. Otherwise the conditions are inferred from the criteria that
        were set explicitly to a non-zero value. Keys containing ``"min"``
        become necessary conditions; all other keys become sufficient ones.

        Raises:
            ValueError: on unknown condition keys, or on a ``min`` key listed as
                a sufficient condition.
        """
        valid_keys = self.get_valid_keys()
        if self.conditions.keys:
            invalid_keys = self.conditions.keys - valid_keys
            if invalid_keys:
                raise ValueError(f"Invalid condition keys: {invalid_keys}")
            min_in_suffi = self.conditions.sufficient & self.get_valid_keys("min")
            if min_in_suffi:
                raise ValueError(
                    f"min keys {min_in_suffi} can not be sufficient conditions"
                )
        else:
            necessary: Set[str] = set()
            sufficient: Set[str] = set()
            for key in self.model_fields_set & valid_keys:
                if getattr(self, key) != 0:
                    (necessary if "min" in key else sufficient).add(key)
            # Build a fresh KeyConditionConfig rather than mutating the current
            # one in place. With pydantic's default revalidate_instances="never",
            # an instance passed as ``conditions=...`` is stored without copying,
            # so in-place ``.add`` would leak into every config sharing it.
            # ``object.__setattr__`` bypasses the frozen check during init.
            object.__setattr__(
                self,
                "conditions",
                KeyConditionConfig(necessary=necessary, sufficient=sufficient),
            )

    @classmethod
    def get_valid_keys(cls, matching: str = "") -> Set[str]:
        """Return the names of the stopping-criterion fields.

        Args:
            matching: if non-empty, only names containing this substring are
                returned (e.g. ``"min"``).

        Returns:
            A new set on every call, so callers may mutate it without affecting
            later calls.
        """
        return set(cls._valid_keys_cached(matching))

    @classmethod
    @cache
    def _valid_keys_cached(cls, matching: str) -> frozenset:
        # Cached per (class, matching). It is a frozenset so that the shared
        # cached value cannot be mutated. "iter_mode" is not a field of this
        # class; excluding it is a no-op here (NOTE(review): possibly for
        # subclasses, confirm).
        keys = cls.model_fields.keys() - {"conditions", "iter_mode"}
        if matching:
            return frozenset(key for key in keys if matching in key)
        return frozenset(keys)
class SaveModelConfig(BaseModel, frozen=True):
    """Settings that control when the model is saved."""

    period: Optional[IntermittentConfig] = None
    """Schedule for periodic saving (None means no periodic schedule is configured)."""
    on_improve: list[Literal["train_loss", "val_loss"]] = ["val_loss"]
    """Metrics whose improvement triggers a save."""
    maximum: list[NonNegativeInt] = [5]
    """Per-metric maximum number of saved models, paired index-by-index with ``on_improve``."""

    def model_post_init(self, context):
        """Require ``on_improve`` and ``maximum`` to pair up one-to-one."""
        lengths_match = len(self.maximum) == len(self.on_improve)
        if not lengths_match:
            raise ValueError("Length of on_improve and maximum must be the same")
class TrainConfig(BaseModel, frozen=True):
    """Settings that only apply to the training stage."""

    model_config = MODEL_CONFIG

    task_id: PositiveInt = 1
    """Identifier of the training task."""
    ewc_threshold: float = 1.0
    """Threshold value used by EWC regularization."""
    ewc_regularization: bool = False
    """If True, EWC regularization is applied while training."""
    iteration: TrainIterationConfig = Field(default_factory=TrainIterationConfig)
    """Stopping and continuation criteria of the training loop."""
    snapshot: list[SnapshotConfig] = []
    """Snapshot configurations; empty means none are configured."""
    save_model: SaveModelConfig = SaveModelConfig()
    """Settings for saving the model."""
class InferConfig(BaseModel, frozen=True):
    """Settings that only apply to inference."""

    model_config = MODEL_CONFIG

    max_rollouts: NonNegativeInt = 0
    """Upper bound on the number of rollouts performed during inference."""
    max_steps: NonNegativeInt = 0
    """Upper bound on the number of steps performed during inference."""
    frequency: int = 0
    """How often (in steps) action commands are sent.

    0 waits for any input after every step; a negative value means no limit.
    """
    frequency_inner: int = 0
    """How often (within each step loop) action commands are sent."""
    rollout_wait: Any = "input"
    """How to wait between rollouts.

    ``'input'`` waits for user input; a numeric value waits that many seconds.
    """
    start_rollout: NonNegativeInt = 0
    """Index of the first rollout."""
class Config(CommonConfig):
    """Top-level configuration: the common settings plus model, data and per-stage sections."""

    # Same strict settings as MODEL_CONFIG, additionally allowing arbitrary
    # (non-pydantic) field types such as the model and data loaders.
    model_config = ConfigDict(**MODEL_CONFIG, arbitrary_types_allowed=True)

    model: ModelLike
    """The model to run."""
    data_loaders: MappingProtocol[DataLoaderKey, ConstrainedIterable]
    """Data loaders, keyed by name."""
    interactor: Optional[InteractorBasis] = None
    """Interactor to use, if any."""
    train: TrainConfig = TrainConfig()
    """Training settings."""
    test: TestConfig = TestConfig()
    """Testing settings."""
    infer: InferConfig = InferConfig()
    """Inference settings."""
    extra: Dict[str, Any] = {}
    """Free-form extra parameters.

    Intermediate values of the hydra config file can be stored here so that
    they do not trigger the extra="forbid" error."""

    @field_serializer("data_loaders", when_used="json")
    def serialize_data_loaders(self, data_loaders: Mapping):
        """JSON-serialize the data loaders, falling back to per-loader ``repr``."""
        try:
            dumped = dump_or_repr(data_loaders)
        except Exception:
            # Best effort: serialization must not fail just because a loader
            # cannot be dumped; record each loader's repr instead.
            dumped = {name: repr(loader) for name, loader in data_loaders.items()}
        return dumped