-
Notifications
You must be signed in to change notification settings - Fork 17
#343: Support collections of tensors in args/kwargs for compile #701
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg.values()): | ||
input_names.add(name) | ||
result = {} | ||
for key, value in arg.items(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if the keys are not strings/implicitly convertible to strings?
for nested_name in sorted(trace_input_map.keys()): | ||
if nested_name.startswith(f"{name}.") or nested_name.startswith(f"{name}["): | ||
nested_tensors.append(trace_input_map[nested_name].trace_tensor) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we just populate the trace_input_map
as we call process_arg
instead of doing this extra step?
|
||
# Handle containers of InputInfo objects | ||
if isinstance(arg, dict): | ||
if any(isinstance(v, (InputInfo, DimensionInputInfo)) for v in arg.values()): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This check needs to be removed to support nested collections. Same for lists below.
else: | ||
return [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When is this branch reached?
if name_prefix in input_info_names: | ||
return [tensor] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can pull this out to the top of the extract_recursive
method and that'll also let you drop the checks on lines 204 and 214 (i.e. you can unconditionally make the recursive call).
No description provided.