
Commit 68e16b2

Partials - Dynamic Config Dataclasses for arbitrary callables (#156)
* Partials feature POC
* Functools black magic, partials are pickleable
* Add postponed annotation version of test
* Apply pre-commit hooks to partial.py
* Fix example, rename typevars
* Add comments in the partials_example.py
* Fix the partials_example.py file
* Add `nested_partial` helper function
* Tweak the partials_example.py
* Fix issue with using functools.partial[T] in py37
* Adding some more tests for Partial
* Simplify `partial.py` a bit
* Add test from PR suggestion, add `sp.config_for`
* Fix missing ``` in docstring
* Remove torch.optim.SGD, fix an old BUG comment
* Improve docstring of `config_for`
* Add `adjust_default` in __all__
* Fix import issue in test_partial_postponed.py
* Remove kw_only, which appeared in py>=3.9
* Update regression files
* Actually use a frozen instance as default in test
* Add `frozen` argument that gets passed through
* Fix doctest

Signed-off-by: Fabrice Normandin <[email protected]>
1 parent 65b07f3 commit 68e16b2

14 files changed (+647 −7 lines)

examples/partials/README.md

+2

```markdown
# Partials - Configuring arbitrary classes / callables
```
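A minimal sketch (not part of the commit) of the idea this example directory demonstrates: `config_for` builds a dataclass from a callable's signature, and instances of that dataclass act as partials for the callable. The class `Foo` here is hypothetical, and calling `config_for` without `ignore_args` is an assumption based on the example below.

```python
from simple_parsing.helpers.partial import Partial, config_for


class Foo:
    def __init__(self, x: int = 1, y: str = "a"):
        self.x, self.y = x, y


# Dynamically create a config dataclass whose fields mirror Foo's arguments.
FooConfig: type[Partial[Foo]] = config_for(Foo)

config = FooConfig(x=2)  # a plain dataclass instance; fields default to Foo's defaults
foo = config()           # calling the config instance constructs a Foo with those values
assert isinstance(foo, Foo) and foo.x == 2 and foo.y == "a"
```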

examples/partials/partials_example.py

+86
```python
from __future__ import annotations

from dataclasses import dataclass

from simple_parsing import ArgumentParser
from simple_parsing.helpers import subgroups
from simple_parsing.helpers.partial import Partial, config_for


# Suppose we want to choose between the Adam and SGD optimizers from PyTorch:
# (NOTE: We don't import pytorch here, so we just create the types to illustrate)
class Optimizer:
    def __init__(self, params):
        ...


class Adam(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 3e-4,
        beta1: float = 0.9,
        beta2: float = 0.999,
        eps: float = 1e-08,
    ):
        self.params = params
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps


class SGD(Optimizer):
    def __init__(
        self,
        params,
        lr: float = 3e-4,
        weight_decay: float | None = None,
        momentum: float = 0.9,
        eps: float = 1e-08,
    ):
        self.params = params
        self.lr = lr
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.eps = eps


# Dynamically create a dataclass that will be used for the above type:
# NOTE: We could use Partial[Adam] or Partial[Optimizer], however this would treat `params` as a
# required argument.
# AdamConfig = Partial[Adam]  # would treat 'params' as a required argument.
# SGDConfig = Partial[SGD]  # same here
AdamConfig: type[Partial[Adam]] = config_for(Adam, ignore_args="params")
SGDConfig: type[Partial[SGD]] = config_for(SGD, ignore_args="params")


@dataclass
class Config:
    # Which optimizer to use.
    optimizer: Partial[Optimizer] = subgroups(
        {
            "sgd": SGDConfig,
            "adam": AdamConfig,
        },
        default_factory=AdamConfig,
    )


parser = ArgumentParser()
parser.add_arguments(Config, "config")
args = parser.parse_args()

config: Config = args.config

print(config)
expected = "Config(optimizer=AdamConfig(lr=0.0003, beta1=0.9, beta2=0.999, eps=1e-08))"

my_model_parameters = [123]  # nn.Sequential(...).parameters()

optimizer = config.optimizer(params=my_model_parameters)
print(vars(optimizer))
expected += """
{'params': [123], 'lr': 0.0003, 'beta1': 0.9, 'beta2': 0.999, 'eps': 1e-08}
"""
```

requirements.txt

+1-1
```diff
@@ -1,2 +1,2 @@
 docstring-parser~=0.15
-typing_extensions>=4.3.0
+typing_extensions>=4.5.0
```

simple_parsing/__init__.py

+4
```diff
@@ -6,8 +6,10 @@
 from .decorators import main
 from .help_formatter import SimpleHelpFormatter
 from .helpers import (
+    Partial,
     Serializable,
     choice,
+    config_for,
     field,
     flag,
     list_field,
@@ -31,6 +33,7 @@
     "ArgumentGenerationMode",
     "ArgumentParser",
     "choice",
+    "config_for",
     "ConflictResolution",
     "DashVariant",
     "field",
@@ -44,6 +47,7 @@
     "parse_known_args",
     "parse",
     "ParsingError",
+    "Partial",
     "replace",
     "Serializable",
     "SimpleHelpFormatter",
```

simple_parsing/helpers/__init__.py

+1
```diff
@@ -2,6 +2,7 @@
 from .fields import *
 from .flatten import FlattenedAccess
 from .hparams import HyperParameters
+from .partial import Partial, config_for
 from .serialization import FrozenSerializable, Serializable, SimpleJsonEncoder, encode

 try:
```
+48
```python
import functools
from typing import Any, Generic, TypeVar

_T = TypeVar("_T")


class npartial(functools.partial, Generic[_T]):
    """Partial that also invokes partials in args and kwargs before feeding them to the function.

    Useful for creating nested partials, e.g.:

    >>> from dataclasses import dataclass, field
    >>> @dataclass
    ... class Value:
    ...     v: int = 0
    >>> @dataclass
    ... class ValueWrapper:
    ...     value: Value
    ...
    >>> from functools import partial
    >>> @dataclass
    ... class WithRegularPartial:
    ...     wrapped: ValueWrapper = field(
    ...         default_factory=partial(ValueWrapper, value=Value(v=123)),
    ...     )

    Here's the problem: this is BAD! Both instances share the same `Value` object!

    >>> WithRegularPartial().wrapped.value is WithRegularPartial().wrapped.value
    True
    >>> @dataclass
    ... class WithNPartial:
    ...     wrapped: ValueWrapper = field(
    ...         default_factory=npartial(ValueWrapper, value=npartial(Value, v=123)),
    ...     )
    >>> WithNPartial().wrapped.value is WithNPartial().wrapped.value
    False

    This is fine now!
    """

    def __call__(self, *args: Any, **keywords: Any) -> _T:
        keywords = {**self.keywords, **keywords}
        args = self.args + args
        # Evaluate any nested npartials in args/kwargs before calling the
        # function, so each call produces fresh inner objects.
        args = tuple(arg() if isinstance(arg, npartial) else arg for arg in args)
        keywords = {k: v() if isinstance(v, npartial) else v for k, v in keywords.items()}
        return self.func(*args, **keywords)
```
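A short sketch (not part of the diff) of the same behavior outside of dataclass fields, using `Value`/`ValueWrapper` classes like those in the docstring above:

```python
from dataclasses import dataclass
from functools import partial


@dataclass
class Value:
    v: int = 0


@dataclass
class ValueWrapper:
    value: Value


# npartial evaluates nested npartials at call time, so every call builds a fresh Value:
make_wrapper = npartial(ValueWrapper, value=npartial(Value, v=1))
assert make_wrapper().value is not make_wrapper().value

# A plain functools.partial captures one pre-built Value, shared by every call:
shared = partial(ValueWrapper, value=Value(v=1))
assert shared().value is shared().value
```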
