Skip to content

Commit 15c6543

Browse files
authored
Return the original predicted scores from SoftmaxSampling (#290)
* Return the original predicted scores from `SoftmaxSampling` * Update the output schema of `SoftmaxSampling` to match * Fix output scores dtype
1 parent a6cb35c commit 15c6543

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

merlin/systems/dag/ops/softmax_sampling.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,12 @@ def compute_output_schema(
9999
self, input_schema: Schema, col_selector: ColumnSelector, prev_output_schema: Schema = None
100100
) -> Schema:
101101
"""Describe the operator's outputs"""
102-
return Schema([ColumnSchema("ordered_ids", dtype=np.int32, dims=(None, 1))])
102+
return Schema(
103+
[
104+
ColumnSchema("ordered_ids", dtype=np.int32, dims=(None, 1)),
105+
ColumnSchema("ordered_scores", dtype=np.float32, dims=(None, 1)),
106+
]
107+
)
103108

104109
def transform(
105110
self, col_selector: ColumnSelector, transformable: Transformable
@@ -133,6 +138,10 @@ def transform(
133138
# This is just bookkeeping to produce the final ordered list of recs
134139
sorted_indices = np.argsort(exponentials)
135140
topk_item_ids = candidate_ids[sorted_indices][: self.topk]
141+
topk_item_scores = predicted_scores[sorted_indices][: self.topk]
136142
ordered_item_ids = topk_item_ids.reshape(1, -1).T
143+
ordered_item_scores = topk_item_scores.reshape(1, -1).T
137144

138-
return type(transformable)({"ordered_ids": ordered_item_ids})
145+
return type(transformable)(
146+
{"ordered_ids": ordered_item_ids, "ordered_scores": ordered_item_scores}
147+
)

0 commit comments

Comments
 (0)