@@ -110,24 +110,6 @@ def compute_input_schema(
110110 root_schema , parents_schema , deps_schema , selector
111111 )
112112
113- if len (parents_schema .column_schemas ) > 1 :
114- raise ValueError (
115- "More than one input has been detected for this node,"
116- / f"inputs received: { input_schema .column_names } "
117- )
118- if len (deps_schema .column_schemas ) > 1 :
119- raise ValueError (
120- "More than one dependency input has been detected"
121- / f"for this node, inputs received: { input_schema .column_names } "
122- )
123-
124- # 1 for deps and 1 for parents
125- if len (input_schema .column_schemas ) > 2 :
126- raise ValueError (
127- "More than one input has been detected for this node,"
128- / f"inputs received: { input_schema .column_names } "
129- )
130-
131113 self ._input_col = parents_schema .column_names [0 ]
132114 self ._filter_out_col = deps_schema .column_names [0 ]
133115
@@ -157,6 +139,27 @@ def compute_output_schema(
157139 """
158140 return Schema ([ColumnSchema ("filtered_ids" , dtype = np .int32 , is_list = False )])
159141
142+ def validate_schemas (
143+ self , parents_schema , deps_schema , input_schema , output_schema , strict_dtypes = False
144+ ):
145+ if len (parents_schema .column_schemas ) > 1 :
146+ raise ValueError (
147+ "More than one input has been detected for this node,"
148+ / f"inputs received: { input_schema .column_names } "
149+ )
150+ if len (deps_schema .column_schemas ) > 1 :
151+ raise ValueError (
152+ "More than one dependency input has been detected"
153+ / f"for this node, inputs received: { input_schema .column_names } "
154+ )
155+
156+ # 1 for deps and 1 for parents
157+ if len (input_schema .column_schemas ) > 2 :
158+ raise ValueError (
159+ "More than one input has been detected for this node,"
160+ / f"inputs received: { input_schema .column_names } "
161+ )
162+
160163 def transform (self , df : InferenceDataFrame ):
161164 """
162165 Transform input dataframe to output dataframe using function logic.
0 commit comments