Hi, I have a question about the expected dimensions for features, targets, and preds in the ConR function. I noticed the use of operations like torch.einsum and flatten, but it's not clear what the intended shapes of these inputs should be.
Currently, I am using the following dimensions:
features: [B, L, D] where B is the batch size, L is the sequence length, and D is the feature dimension.
targets: [B, L]
preds: [B, L]
Could you confirm whether these dimensions align with the function’s intended input? Also, are there any specific assumptions regarding L or D that we should be aware of when computing the loss?
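In case it helps, here is a minimal sketch of my current inputs and of the fallback I would use if ConR expects one row per sample rather than a separate sequence axis. The sizes are made up, and the [N, D] / [N, 1] layout is only my guess, not something I found documented in the repository.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
B, L, D = 4, 10, 16

# Shapes I am currently passing in.
features = torch.randn(B, L, D)  # [B, L, D]
targets = torch.randn(B, L)      # [B, L]
preds = torch.randn(B, L)        # [B, L]

# Assumption (not confirmed): the loss treats every row as one sample,
# so the batch and sequence axes are merged into a single axis N = B * L.
flat_features = features.reshape(B * L, D)  # [N, D]
flat_targets = targets.reshape(B * L, 1)    # [N, 1]
flat_preds = preds.reshape(B * L, 1)        # [N, 1]

print(flat_features.shape, flat_targets.shape, flat_preds.shape)
```

If this flattening is the intended usage, every position in every sequence would be compared against every other position in the batch, which is part of why I am asking about assumptions on L.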
Thanks in advance for any clarification!