@@ -99,7 +99,12 @@ def compute_output_schema(
9999 self , input_schema : Schema , col_selector : ColumnSelector , prev_output_schema : Schema = None
100100 ) -> Schema :
101101 """Describe the operator's outputs"""
102- return Schema ([ColumnSchema ("ordered_ids" , dtype = np .int32 , dims = (None , 1 ))])
102+ return Schema (
103+ [
104+ ColumnSchema ("ordered_ids" , dtype = np .int32 , dims = (None , 1 )),
105+ ColumnSchema ("ordered_scores" , dtype = np .float32 , dims = (None , 1 )),
106+ ]
107+ )
103108
104109 def transform (
105110 self , col_selector : ColumnSelector , transformable : Transformable
@@ -133,6 +138,10 @@ def transform(
133138 # This is just bookkeeping to produce the final ordered list of recs
134139 sorted_indices = np .argsort (exponentials )
135140 topk_item_ids = candidate_ids [sorted_indices ][: self .topk ]
141+ topk_item_scores = predicted_scores [sorted_indices ][: self .topk ]
136142 ordered_item_ids = topk_item_ids .reshape (1 , - 1 ).T
143+ ordered_item_scores = topk_item_scores .reshape (1 , - 1 ).T
137144
138- return type (transformable )({"ordered_ids" : ordered_item_ids })
145+ return type (transformable )(
146+ {"ordered_ids" : ordered_item_ids , "ordered_scores" : ordered_item_scores }
147+ )
0 commit comments