-
Notifications
You must be signed in to change notification settings - Fork 339
Use lambda instead of functools.partial to create single-host config maker #1152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
I have not encountered this issue. How can I reproduce it? |
@@ -1011,7 +1011,6 @@ def make_single_host_config(base_config_name: str) -> SpmdTrainer.Config: | |||
return cfg | |||
|
|||
# Make single-host config | |||
make_single_host_config_func = functools.partial(make_single_host_config, config_name) | |||
config_map[f"{config_name}-single-host"] = make_single_host_config_func | |||
config_map[f"{config_name}-single-host"] = lambda: make_single_host_config(config_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC, this won't work as is due to how closures work in Python?
config_map[f"{config_name}-single-host"] = lambda: make_single_host_config(config_name) | |
config_map[f"{config_name}-single-host"] = lambda config_name=config_name: make_single_host_config(config_name) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That should work just fine, like how we captured things in the idiomatic way of defining decorators.
Here is a simple test that captures the pattern:
def f():
x = 1
return lambda: print(x)
g = f()
g()
which correctly prints 1 when we called g()
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe I missed something, but why doesn't the code in the PR exhibit the same problem as the following code:
fns = []
for i in range(3):
fns.append(lambda: i)
print( fns[0]() ) # Incorrectly prints 2 instead of 0.
I agree your snippet doesn't have this issue because the value being closed over is not modified.
This is only triggered if using one of the *-single-host configs, e.g. |
Yeah I have been using single-host config for my testing on TPU recently. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will defer to @apghml for approval.
Fixes #1151
An alternative solution is to use
functools.wraps
but that requires a custom config validator: