@@ -373,8 +373,9 @@ def __getitem__(self, grid) -> T:
373373
374374def serialize_specialization_data (name , signature , constants , attrs , options , key ):
375375 constants = {
376- key : str (value ) if value .__class__ .__name__ == "dtype" else
377- {"constexpr" : value .value } if value .__class__ .__name__ == "constexpr" else value
376+ key : str (value ) if value .__class__ .__name__ == "dtype" else {"constexpr" : value .value }
377+ if value .__class__ .__name__ == "constexpr" else {"jit_function" : f"{ value .module } :{ value .fn .__qualname__ } " }
378+ if value .__class__ .__name__ == "JITFunction" else value
378379 for key , value in constants .items ()
379380 }
380381
@@ -560,6 +561,9 @@ def _get_src(self):
560561 src = property (fget = _get_src , fset = _set_src )
561562
562563
564+ _triton_jit_function_registry = {}
565+
566+
563567@dataclass
564568class JitFunctionInfo :
565569 module : ModuleType
@@ -771,6 +775,8 @@ def __init__(self, fn, version=None, do_not_specialize=None, do_not_specialize_o
771775 self .do_not_specialize_on_alignment = do_not_specialize_on_alignment
772776 self ._repr = repr
773777 self .launch_metadata = launch_metadata
778+ # Register for simple deserialization of JITFunction constants
779+ _triton_jit_function_registry [f"{ self .module } :{ self .fn .__qualname__ } " ] = self
774780
775781 self .params = []
776782 for i , param in enumerate (self .signature .parameters .values ()):
@@ -805,12 +811,21 @@ def preload(self, specialization_data):
805811 f"Specialization data is for { deserialized_obj ['name' ]} but trying to preload for { self ._fn_name } " )
806812 constant_keys = map (tuple , deserialized_obj ['constant_keys' ])
807813 constant_vals = deserialized_obj ['constant_vals' ]
808- constexprs = {
809- key :
810- tl .dtype (value ) if tl .dtype .is_dtype (value ) else
811- tl .constexpr (value ['constexpr' ]) if isinstance (value , dict ) and 'constexpr' in value else value
812- for key , value in zip (constant_keys , constant_vals )
813- }
814+
815+ def _decode_constant (value ):
816+ if tl .dtype .is_dtype (value ):
817+ return tl .dtype (value )
818+ if isinstance (value , dict ):
819+ if 'constexpr' in value :
820+ return tl .constexpr (value ['constexpr' ])
821+ if 'jit_function' in value :
822+ jf_key = value ['jit_function' ]
823+ if jf_key in _triton_jit_function_registry :
824+ return _triton_jit_function_registry [jf_key ]
825+ raise RuntimeError (f"Unable to resolve JITFunction { jf_key } for preload" )
826+ return value
827+
828+ constexprs = {key : _decode_constant (value ) for key , value in zip (constant_keys , constant_vals )}
814829 attrs_keys = map (tuple , deserialized_obj ['attrs_keys' ])
815830 attrs_vals = deserialized_obj ['attrs_vals' ]
816831 attrs = dict (zip (attrs_keys , attrs_vals ))
0 commit comments