@@ -2680,3 +2680,76 @@ def test_SpatialEncoder(max_dist, num_kernels, num_heads):
26802680 encoding3d_2 = model_3 (coord , node_type )
26812681 assert encoding3d_1 .shape == (bsz , max_num_nodes , max_num_nodes , num_heads )
26822682 assert encoding3d_2 .shape == (bsz , max_num_nodes , max_num_nodes , num_heads )
2683+
2684+
2685+ @pytest .mark .parametrize ("residual" , [True , False ])
2686+ def test_conv_with_zero_nodes_bugfix_7894 (residual ):
2687+ """Test for PR #7894 in DGL where HeteroGraphConv with zero nodes in a
2688+ specific node type would cause an error due to empty tensors.
2689+ This test ensures that GATConv, GATv2Conv, and EdgeGATConv can handle
2690+ such cases without raising errors.
2691+ """
2692+ # Create a heterogeneous graph with zero nodes in the "tag" type
2693+ user_item_src = torch .tensor ([0 , 1 , 2 ])
2694+ user_item_dst = torch .tensor ([4 , 5 , 6 ])
2695+
2696+ user_tag_src = torch .tensor ([], dtype = torch .int64 )
2697+ user_tag_dst = torch .tensor ([], dtype = torch .int64 )
2698+
2699+ num_nodes_dict = {
2700+ "user" : 5 ,
2701+ "item" : 10 ,
2702+ "tag" : 0 ,
2703+ }
2704+
2705+ data_dict = {
2706+ ("user" , "buys" , "item" ): (user_item_src , user_item_dst ),
2707+ ("user" , "likes" , "tag" ): (user_tag_src , user_tag_dst ),
2708+ }
2709+
2710+ g = dgl .heterograph (data_dict , num_nodes_dict = num_nodes_dict )
2711+
2712+ feat_dim = 16
2713+ node_features = {
2714+ "user" : torch .randn (num_nodes_dict ["user" ], feat_dim ),
2715+ "item" : torch .randn (num_nodes_dict ["item" ], feat_dim ),
2716+ "tag" : torch .randn (num_nodes_dict ["tag" ], feat_dim ),
2717+ }
2718+ edge_features = {
2719+ ("user" , "buys" , "item" ): torch .randn (g .num_edges (("user" , "buys" , "item" )), feat_dim ),
2720+ ("user" , "likes" , "tag" ): torch .randn (g .num_edges (("user" , "likes" , "tag" )), feat_dim ),
2721+ }
2722+
2723+ # Test GATConv with zero nodes in "tag" type
2724+ conv = nn .HeteroGraphConv ({
2725+ ("user" , "buys" , "item" ): nn .GATConv (16 , 2 , num_heads = 2 , residual = residual ),
2726+ ("user" , "likes" , "tag" ): nn .GATConv (16 , 2 , num_heads = 2 , residual = residual ),
2727+ }, aggregate = "sum" )
2728+ out = conv (g , node_features )
2729+ assert out ["item" ].shape == (10 , 2 , 2 )
2730+ assert out ["tag" ].shape == (0 , 2 , 2 )
2731+ assert "user" not in out
2732+
2733+ # Test GATv2Conv with zero nodes in "tag" type
2734+ conv_v2 = nn .HeteroGraphConv ({
2735+ ("user" , "buys" , "item" ): nn .GATv2Conv (16 , 2 , num_heads = 2 , residual = residual ),
2736+ ("user" , "likes" , "tag" ): nn .GATv2Conv (16 , 2 , num_heads = 2 , residual = residual ),
2737+ }, aggregate = "sum" )
2738+ out_v2 = conv_v2 (g , node_features )
2739+ assert out_v2 ["item" ].shape == (10 , 2 , 2 )
2740+ assert out_v2 ["tag" ].shape == (0 , 2 , 2 )
2741+ assert "user" not in out_v2
2742+
2743+ # Test EdgeGATConv with zero nodes in "tag" type
2744+ edge_conv = nn .HeteroGraphConv ({
2745+ ("user" , "buys" , "item" ): nn .EdgeGATConv (16 , 16 , 2 , num_heads = 2 , residual = residual ),
2746+ ("user" , "likes" , "tag" ): nn .EdgeGATConv (16 , 16 , 2 , num_heads = 2 , residual = residual ),
2747+ }, aggregate = "sum" )
2748+ mod_kwargs = {
2749+ "buys" : {"edge_feat" : edge_features [("user" , "buys" , "item" )]},
2750+ "likes" : {"edge_feat" : edge_features [("user" , "likes" , "tag" )]},
2751+ }
2752+ out_edge = edge_conv (g , node_features , mod_kwargs = mod_kwargs )
2753+ assert out_edge ["item" ].shape == (10 , 2 , 2 )
2754+ assert out_edge ["tag" ].shape == (0 , 2 , 2 )
2755+ assert "user" not in out_edge
0 commit comments