@@ -242,3 +242,116 @@ def __mul__(self, scalar: Any) -> "KRR":
242242 new = copy .deepcopy (self )
243243 self .betas_ *= scalar
244244 return new
245+
246+ def to_onnx (self ) -> Any :
247+ from onnx import numpy_helper
248+ from onnx .checker import check_model
249+ from onnx .helper import (
250+ make_graph ,
251+ make_model ,
252+ make_node ,
253+ make_tensor_value_info ,
254+ np_dtype_to_tensor_dtype ,
255+ )
256+
257+ assert self .X_train .dtype == self .betas_ .dtype
258+
259+ def make_constant_node (value : cn .array , name : str ) -> Any :
260+ return make_node (
261+ "Constant" ,
262+ inputs = [],
263+ value = numpy_helper .from_array (value , name = name ),
264+ outputs = [name ],
265+ )
266+
267+ nodes = []
268+
269+ # model constants
270+ betas = numpy_helper .from_array (self .betas_ .__array__ (), name = "betas" )
271+ X_train = numpy_helper .from_array (self .X_train .__array__ (), name = "X_train" )
272+
273+ # pred inputs
274+ X = make_tensor_value_info (
275+ "X" ,
276+ np_dtype_to_tensor_dtype (self .betas_ .dtype ),
277+ [None , self .X_train .shape [1 ]],
278+ )
279+ pred = make_tensor_value_info (
280+ "pred" ,
281+ np_dtype_to_tensor_dtype (self .betas_ .dtype ),
282+ [None , self .betas_ .shape [1 ]],
283+ )
284+
285+ # exanded l2 distance
286+ # distance = np.sum(X**2, axis=1)[:, np.newaxis] - 2 * np.dot(X, self.X_train.T)
287+ # + np.sum(self.X_train**2, axis=1)
288+ make_tensor_value_info (
289+ "XX" , np_dtype_to_tensor_dtype (self .betas_ .dtype ), [None ]
290+ )
291+ make_tensor_value_info (
292+ "YY" ,
293+ np_dtype_to_tensor_dtype (self .betas_ .dtype ),
294+ [self .X_train .shape [0 ], 1 ],
295+ )
296+ make_tensor_value_info (
297+ "XY_reshaped" ,
298+ np_dtype_to_tensor_dtype (self .betas_ .dtype ),
299+ [1 , self .X_train .shape [0 ]],
300+ )
301+ make_tensor_value_info (
302+ "XY" ,
303+ np_dtype_to_tensor_dtype (self .betas_ .dtype ),
304+ [None , self .X_train .shape [0 ]],
305+ )
306+ nodes .append (make_constant_node (np .array ([1 ]), "axis1" ))
307+ nodes .append (make_node ("ReduceSumSquare" , ["X" , "axis1" ], ["XX" ]))
308+ nodes .append (make_node ("Gemm" , ["X" , "X_train" ], ["XY" ], alpha = - 2.0 , transB = 1 ))
309+ nodes .append (make_node ("ReduceSumSquare" , ["X_train" , "axis1" ], ["YY" ]))
310+ nodes .append (make_constant_node (np .array ([1 , - 1 ]), "reshape" ))
311+ nodes .append (make_node ("Reshape" , ["YY" , "reshape" ], ["YY_reshaped" ]))
312+ nodes .append (make_node ("Add" , ["XX" , "XY" ], ["add0" ]))
313+ make_tensor_value_info (
314+ "l2" ,
315+ np_dtype_to_tensor_dtype (self .betas_ .dtype ),
316+ [None , self .X_train .shape [0 ]],
317+ )
318+ nodes .append (make_node ("Add" , ["YY_reshaped" , "add0" ], ["l2" ]))
319+ nodes .append (make_constant_node (np .array ([0.0 ], self .betas_ .dtype ), "zero" ))
320+ make_tensor_value_info (
321+ "l2_clipped" ,
322+ np_dtype_to_tensor_dtype (self .betas_ .dtype ),
323+ [None , self .X_train .shape [0 ]],
324+ )
325+ nodes .append (make_node ("Max" , ["l2" , "zero" ], ["l2_clipped" ]))
326+
327+ # RBF kernel
328+ # K = np.exp(-distance / (2 * self.sigma**2))
329+ make_tensor_value_info (
330+ "rbf0" ,
331+ np_dtype_to_tensor_dtype (self .betas_ .dtype ),
332+ [None , self .X_train .shape [0 ]],
333+ )
334+ if self .sigma is None :
335+ raise ValueError ("sigma is None. Has fit been called?" )
336+ nodes .append (
337+ make_constant_node (
338+ np .array ([- 2.0 * self .sigma ** 2 ], self .betas_ .dtype ), "denominator"
339+ )
340+ )
341+ nodes .append (make_node ("Div" , ["l2_clipped" , "denominator" ], ["rbf0" ]))
342+ make_tensor_value_info (
343+ "K" ,
344+ np_dtype_to_tensor_dtype (self .betas_ .dtype ),
345+ [None , self .X_train .shape [0 ]],
346+ )
347+ nodes .append (make_node ("Exp" , ["rbf0" ], ["K" ]))
348+
349+ # prediction
350+ # pred = np.dot(K, self.betas_)
351+ nodes .append (make_node ("MatMul" , ["K" , "betas" ], ["pred" ]))
352+ graph = make_graph (
353+ nodes , "legateboost.model.KRR" , [X ], [pred ], [betas , X_train ]
354+ )
355+ onnx_model = make_model (graph )
356+ check_model (onnx_model )
357+ return onnx_model
0 commit comments