Description
Describe the bug
The documentation for TensorDictModule is not quite correct. In the param spec, it says:
out_keys (iterable of str) – keys to be written to the input tensordict. The length of out_keys must match the number of tensors returned by the embedded module. Using “_” as a key avoid writing tensor to output
Specifically, the word "must" is misleading: it implies that a mismatch between the number of tensors returned by the module and the number of supplied keys will raise an error. In fact, TensorDictModule silently ignores such a mismatch, because it pairs the outputs with the keys through the zip() call in tensordict/nn/common.py#L1058, and zip() simply stops at the shorter of its inputs.
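A minimal plain-Python illustration of why the mismatch goes unnoticed. This does not use tensordict: the `write_outputs` helper and the plain dict standing in for a tensordict are hypothetical, and only mimic the key/output pairing step. Surplus outputs or surplus keys are dropped without any error.

```python
def write_outputs(out_keys, outputs, dest):
    """Hypothetical stand-in for pairing out_keys with module outputs.

    zip() stops at the shorter iterable, so any surplus keys or
    surplus outputs are silently discarded.
    """
    for key, value in zip(out_keys, outputs):
        dest[key] = value
    return dest


# The "module" returns three values, but only two keys were supplied.
print(write_outputs(["a", "b"], (1, 2, 3), {}))  # {'a': 1, 'b': 2}
# Three keys were supplied, but the "module" returns only two values.
print(write_outputs(["a", "b", "c"], (1, 2), {}))  # {'a': 1, 'b': 2}
```

In both cases the call succeeds and the mismatch is invisible to the caller.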
Solution
I'm not sure what the right fix is here. The current behaviour works, but at minimum the documentation should say that a length mismatch is silently ignored rather than treated as an error. Personally, I think it would be better if a mismatch actually raised a ValueError, but that likely wouldn't be backwards compatible. Just wanted to bring this up!
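For reference, one possible shape for the stricter behaviour, again as a hypothetical plain-Python sketch rather than tensordict code: `write_outputs_strict` is an illustrative name, and the dict stands in for a tensordict. It checks the lengths up front and raises ValueError before anything is written.

```python
def write_outputs_strict(out_keys, outputs, dest):
    """Hypothetical strict variant: reject mismatched lengths up front."""
    out_keys = list(out_keys)
    outputs = tuple(outputs)
    if len(out_keys) != len(outputs):
        raise ValueError(
            f"Got {len(outputs)} outputs but {len(out_keys)} out_keys."
        )
    # Lengths match, so zip() pairs every key with exactly one output.
    for key, value in zip(out_keys, outputs):
        dest[key] = value
    return dest


try:
    write_outputs_strict(["a", "b"], (1, 2, 3), {})
except ValueError as err:
    print(err)  # Got 3 outputs but 2 out_keys.
```

On Python 3.10+, `zip(..., strict=True)` also raises ValueError on a length mismatch, but only after the matching pairs have already been yielded, so some keys could be written before the error surfaces; checking first avoids that partial write.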