Commit bac77ff

gag1jain authored and facebook-github-bot committed
Logging KT's key order warning only once (#2548)
Summary: this warning is very noisy; change it to print only once instead of on every call.

Reviewed By: TroyGarden

Differential Revision: D65700079
1 parent 9a4d8a8 commit bac77ff

File tree

1 file changed: +11 −4 lines changed

torchrec/sparse/jagged_tensor.py

+11 −4

@@ -3380,14 +3380,21 @@ def _kt_unflatten(
     return KeyedTensor(context[0], context[1], values[0])


+print_flatten_spec_warn = True
+
+
 def _kt_flatten_spec(kt: KeyedTensor, spec: TreeSpec) -> List[torch.Tensor]:
     _keys, _length_per_key = spec.context
     # please read https://fburl.com/workplace/8bei5iju for more context,
     # you can also consider use short_circuit_pytree_ebc_regroup with KTRegroupAsDict
-    logger.warning(
-        "KT's key order might change from spec from the torch.export, this could have perf impact. "
-        f"{kt.keys()} vs {_keys}"
-    )
+    global print_flatten_spec_warn
+    if print_flatten_spec_warn:
+        logger.warning(
+            "KT's key order might change from spec from the torch.export, this could have perf impact. "
+            f"{kt.keys()} vs {_keys}"
+        )
+        print_flatten_spec_warn = False
+
     res = permute_multi_embedding([kt], [_keys])
     return [res[0]]

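For context, here is a minimal, self-contained sketch of the same warn-once pattern the diff uses: a module-level flag that is flipped after the first warning. The function and variable names below are illustrative stand-ins, not torchrec APIs.

import logging

logger = logging.getLogger(__name__)

# Module-level flag, mirroring print_flatten_spec_warn in the diff above:
# flipped to False after the first warning so subsequent calls skip the log.
_warn_once = True


def flatten_spec_demo(actual_keys, spec_keys):
    # Illustrative stand-in for _kt_flatten_spec (not the torchrec function).
    global _warn_once
    if _warn_once:
        logger.warning(
            "KT's key order might change from spec from the torch.export, "
            f"this could have perf impact. {actual_keys} vs {spec_keys}"
        )
        _warn_once = False
    # ... real code would go on to permute/flatten the tensors here ...


# Only the first call emits the warning; the second is silent.
flatten_spec_demo(["f1", "f2"], ["f2", "f1"])
flatten_spec_demo(["f1", "f2"], ["f2", "f1"])

One trade-off of this approach, as in the commit, is that the flag is process-wide: only the first flatten in a process reports the key-order mismatch, and later mismatches go unlogged.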