Skip to content

Commit bbcc55c

Browse files
derekwan1meta-codesync[bot]
authored andcommitted
Classify authoring issues with model arch as input errors (#3837)
Summary: Pull Request resolved: #3837 Sometimes users have mistakes in model arch, where an expected prediction / weight name is not present in model output when torchrec expects it to be there. torchrec is open source code, so we can't change that assertion error to throw TrainingPlatforUserError. However, we can change it to raise a value error with a recognizable regex pattern so we can re-classify it in our logger as input error Differential Revision: D95415280 fbshipit-source-id: 50aa80e5092d5ed67918b675847614d89df23622
1 parent 9ac93c6 commit bbcc55c

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

torchrec/metrics/model_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,12 @@ def parse_model_outputs(
6161
assert not weight_name, "weight name must be empty if prediction name is empty"
6262
return (labels, None, None)
6363
assert isinstance(labels, torch.Tensor)
64+
if prediction_name not in model_out:
65+
raise ValueError(f"Prediction name {prediction_name} not found in model output")
6466
predictions = model_out[prediction_name].squeeze()
6567
assert isinstance(predictions, torch.Tensor)
68+
if weight_name not in model_out:
69+
raise ValueError(f"Weight {weight_name} not found in model output")
6670
weights = model_out[weight_name].squeeze()
6771
assert isinstance(weights, torch.Tensor)
6872

@@ -81,25 +85,25 @@ def parse_model_outputs(
8185
if is_vector_valued_label_and_prediction:
8286
logger.warning(
8387
f"""
84-
Vector valued labels and predictions are provided.
88+
Vector valued labels and predictions are provided.
8589
86-
For vector valued label and prediction we should have shapes
90+
For vector valued label and prediction we should have shapes
8791
labels.shape: (batch_size, dim_vector_valued_label)
8892
predictions.shape: (batch_size, dim_vector_valued_prediction)
8993
weights.shape: (batch_size,)
9094
91-
The provided labels, predictions and weights comply with the conditions for vector valued labels and predictions.
92-
These conditions are:
95+
The provided labels, predictions and weights comply with the conditions for vector valued labels and predictions.
96+
These conditions are:
9397
1. labels.dim() == 2
9498
2. predictions.dim() == 2
9599
3. weights.dim() == 1
96100
4. labels.size()[0] == predictions.size()[0]
97101
5. labels.size()[0] == weights.size()[0]
98102
99-
The shapes of labels, predictions and weights are:
100-
labels.shape == {labels.shape},
101-
predictions.shape == {predictions.shape},
102-
weights.shape == {weights.shape}
103+
The shapes of labels, predictions and weights are:
104+
labels.shape == {labels.shape},
105+
predictions.shape == {predictions.shape},
106+
weights.shape == {weights.shape}
103107
"""
104108
)
105109
else:

0 commit comments

Comments
 (0)