@@ -16,8 +16,12 @@ def _cat_one(t):
1616
1717class _CastBatchedDisjointBase (Layer ):
1818
19- def __init__ (self , reverse_indices : bool = False , dtype_batch : str = "int64" , dtype_index = None ,
20- padded_disjoint : bool = False , uses_mask : bool = False ,
19+ def __init__ (self ,
20+ reverse_indices : bool = False ,
21+ dtype_batch : str = "int64" ,
22+ dtype_index = None ,
23+ padded_disjoint : bool = False ,
24+ uses_mask : bool = False ,
2125 static_batched_node_output_shape : tuple = None ,
2226 static_batched_edge_output_shape : tuple = None ,
2327 remove_padded_disjoint_from_batched_output : bool = True ,
@@ -29,20 +33,26 @@ def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dt
2933 dtype_batch (str): Dtype for batch ID tensor. Default is 'int64'.
3034 dtype_index (str): Dtype for index tensor. Default is None.
3135 padded_disjoint (bool): Whether to keep padding in disjoint representation. Default is False.
36+ Not used for ragged arguments.
3237 uses_mask (bool): Whether the padding is marked by a boolean mask or by a length tensor, counting the
3338 non-padded nodes from index 0. Default is False.
39+ Not used for ragged arguments.
3440 static_batched_node_output_shape (tuple): Statical output shape of nodes. Default is None.
41+ Not used for ragged arguments.
3542 static_batched_edge_output_shape (tuple): Statical output shape of edges. Default is None.
43+ Not used for ragged arguments.
3644 remove_padded_disjoint_from_batched_output (bool): Whether to remove the first element on batched output
3745 in case of padding.
46+ Not used for ragged arguments.
3847 """
3948 super (_CastBatchedDisjointBase , self ).__init__ (** kwargs )
4049 self .reverse_indices = reverse_indices
4150 self .dtype_index = dtype_index
4251 self .dtype_batch = dtype_batch
4352 self .uses_mask = uses_mask
4453 self .padded_disjoint = padded_disjoint
45- self .supports_jit = padded_disjoint
54+ if padded_disjoint :
55+ self .supports_jit = True
4656 self .static_batched_node_output_shape = static_batched_node_output_shape
4757 self .static_batched_edge_output_shape = static_batched_edge_output_shape
4858 self .remove_padded_disjoint_from_batched_output = remove_padded_disjoint_from_batched_output
@@ -536,31 +546,7 @@ def call(self, inputs: list, **kwargs):
536546CastBatchedGraphStateToDisjoint .__init__ .__doc__ = _CastBatchedDisjointBase .__init__ .__doc__
537547
538548
539- class _CastRaggedToDisjointBase (Layer ):
540-
541- def __init__ (self , reverse_indices : bool = False , dtype_batch : str = "int64" , dtype_index = None , ** kwargs ):
542- r"""Initialize layer.
543-
544- Args:
545- reverse_indices (bool): Whether to reverse index order. Default is False.
546- dtype_batch (str): Dtype for batch ID tensor. Default is 'int64'.
547- dtype_index (str): Dtype for index tensor. Default is None.
548- """
549- super (_CastRaggedToDisjointBase , self ).__init__ (** kwargs )
550- self .reverse_indices = reverse_indices
551- self .dtype_index = dtype_index
552- self .dtype_batch = dtype_batch
553- # self.supports_jit = False
554-
555- def get_config (self ):
556- """Get config dictionary for this layer."""
557- config = super (_CastRaggedToDisjointBase , self ).get_config ()
558- config .update ({"reverse_indices" : self .reverse_indices , "dtype_batch" : self .dtype_batch ,
559- "dtype_index" : self .dtype_index })
560- return config
561-
562-
563- class CastRaggedAttributesToDisjoint (_CastRaggedToDisjointBase ):
549+ class CastRaggedAttributesToDisjoint (_CastBatchedDisjointBase ):
564550
565551 def __init__ (self , ** kwargs ):
566552 super (CastRaggedAttributesToDisjoint , self ).__init__ (** kwargs )
@@ -598,10 +584,10 @@ def call(self, inputs, **kwargs):
598584 return decompose_ragged_tensor (inputs , batch_dtype = self .dtype_batch )
599585
600586
601- CastRaggedAttributesToDisjoint .__init__ .__doc__ = _CastRaggedToDisjointBase .__init__ .__doc__
587+ CastRaggedAttributesToDisjoint .__init__ .__doc__ = _CastBatchedDisjointBase .__init__ .__doc__
602588
603589
604- class CastRaggedIndicesToDisjoint (_CastRaggedToDisjointBase ):
590+ class CastRaggedIndicesToDisjoint (_CastBatchedDisjointBase ):
605591
606592 def __init__ (self , ** kwargs ):
607593 super (CastRaggedIndicesToDisjoint , self ).__init__ (** kwargs )
@@ -685,10 +671,10 @@ def call(self, inputs, **kwargs):
685671 return [nodes_flatten , disjoint_indices , graph_id_node , graph_id_edge , node_id , edge_id , node_len , edge_len ]
686672
687673
688- CastRaggedIndicesToDisjoint .__init__ .__doc__ = _CastRaggedToDisjointBase .__init__ .__doc__
674+ CastRaggedIndicesToDisjoint .__init__ .__doc__ = _CastBatchedDisjointBase .__init__ .__doc__
689675
690676
691- class CastDisjointToRaggedAttributes (_CastRaggedToDisjointBase ):
677+ class CastDisjointToRaggedAttributes (_CastBatchedDisjointBase ):
692678
693679 def __init__ (self , ** kwargs ):
694680 super (CastDisjointToRaggedAttributes , self ).__init__ (** kwargs )
@@ -713,4 +699,4 @@ def call(self, inputs, **kwargs):
713699 raise NotImplementedError ()
714700
715701
716- CastDisjointToRaggedAttributes .__init__ .__doc__ = CastDisjointToRaggedAttributes .__init__ .__doc__
702+ CastDisjointToRaggedAttributes .__init__ .__doc__ = _CastBatchedDisjointBase .__init__ .__doc__
0 commit comments