@@ -99,7 +99,12 @@ def compute_output_schema(
99
99
self , input_schema : Schema , col_selector : ColumnSelector , prev_output_schema : Schema = None
100
100
) -> Schema :
101
101
"""Describe the operator's outputs"""
102
- return Schema ([ColumnSchema ("ordered_ids" , dtype = np .int32 , dims = (None , 1 ))])
102
+ return Schema (
103
+ [
104
+ ColumnSchema ("ordered_ids" , dtype = np .int32 , dims = (None , 1 )),
105
+ ColumnSchema ("ordered_scores" , dtype = np .float32 , dims = (None , 1 )),
106
+ ]
107
+ )
103
108
104
109
def transform (
105
110
self , col_selector : ColumnSelector , transformable : Transformable
@@ -133,6 +138,10 @@ def transform(
133
138
# This is just bookkeeping to produce the final ordered list of recs
134
139
sorted_indices = np .argsort (exponentials )
135
140
topk_item_ids = candidate_ids [sorted_indices ][: self .topk ]
141
+ topk_item_scores = predicted_scores [sorted_indices ][: self .topk ]
136
142
ordered_item_ids = topk_item_ids .reshape (1 , - 1 ).T
143
+ ordered_item_scores = topk_item_scores .reshape (1 , - 1 ).T
137
144
138
- return type (transformable )({"ordered_ids" : ordered_item_ids })
145
+ return type (transformable )(
146
+ {"ordered_ids" : ordered_item_ids , "ordered_scores" : ordered_item_scores }
147
+ )
0 commit comments