# limitations under the License.
# ==============================================================================
"""Tests for dgi."""
+ import os
+
+ from absl.testing import parameterized
import tensorflow as tf
+ import tensorflow.__internal__.distribute as tfdistribute
+ import tensorflow.__internal__.test as tftest
import tensorflow_gnn as tfgnn

from tensorflow_gnn.runner import orchestration
""" % tfgnn.HIDDEN_STATE


- class DeepGraphInfomaxTest(tf.test.TestCase):
-
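+ # Enumerates the eager tf.distribute strategies to exercise in parameterized
+ # tests; only the single-host MirroredStrategy variants are enabled here.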
+ def _all_eager_distributed_strategy_combinations():
+   strategies = [
+       # MirroredStrategy
+       tfdistribute.combinations.mirrored_strategy_with_gpu_and_cpu,
+       tfdistribute.combinations.mirrored_strategy_with_one_cpu,
+       tfdistribute.combinations.mirrored_strategy_with_one_gpu,
+       # MultiWorkerMirroredStrategy
+       # tfdistribute.combinations.multi_worker_mirrored_2x1_cpu,
+       # tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
+       # TPUStrategy
+       # tfdistribute.combinations.tpu_strategy,
+       # tfdistribute.combinations.tpu_strategy_one_core,
+       # tfdistribute.combinations.tpu_strategy_packed_var,
+       # ParameterServerStrategy
+       # tfdistribute.combinations.parameter_server_strategy_3worker_2ps_cpu,
+       # tfdistribute.combinations.parameter_server_strategy_3worker_2ps_1gpu,
+       # tfdistribute.combinations.parameter_server_strategy_1worker_2ps_cpu,
+       # tfdistribute.combinations.parameter_server_strategy_1worker_2ps_1gpu,
+   ]
+   return tftest.combinations.combine(distribution=strategies)
+
+
+ class DeepGraphInfomaxTest(tf.test.TestCase, parameterized.TestCase):
+
+   global_batch_size = 2
  gtspec = tfgnn.create_graph_spec_from_schema_pb(tfgnn.parse_schema(SCHEMA))
-   task = dgi.DeepGraphInfomax("node", seed=8191)
+   seed = 8191
+   task = dgi.DeepGraphInfomax(
+       "node", global_batch_size=global_batch_size, seed=seed)
+
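+   # A small fixed graph (3 nodes, 2 edges) with constant features, so the
+   # loss values asserted in the tests below are reproducible.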
+   def get_graph_tensor(self):
+     gt = tfgnn.GraphTensor.from_pieces(
+         node_sets={
+             "node":
+                 tfgnn.NodeSet.from_fields(
+                     features={
+                         tfgnn.HIDDEN_STATE:
+                             tf.convert_to_tensor([[1., 2., 3., 4.],
+                                                   [11., 11., 11., 11.],
+                                                   [19., 19., 19., 19.]])
+                     },
+                     sizes=tf.convert_to_tensor([3])),
+         },
+         edge_sets={
+             "edge":
+                 tfgnn.EdgeSet.from_fields(
+                     sizes=tf.convert_to_tensor([2]),
+                     adjacency=tfgnn.Adjacency.from_indices(
+                         ("node", tf.convert_to_tensor([0, 1], dtype=tf.int32)),
+                         ("node", tf.convert_to_tensor([2, 0], dtype=tf.int32)),
+                     )),
+         })
+     return gt

  def build_model(self):
    graph = inputs = tf.keras.layers.Input(type_spec=self.gtspec)
@@ -56,7 +110,9 @@ def build_model(self):
        "edge",
        tfgnn.TARGET,
        feature_name=tfgnn.HIDDEN_STATE)
-     messages = tf.keras.layers.Dense(16)(values)
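+     # Constant kernel initializers keep the model weights, and hence the loss
+     # values asserted in the tests below, deterministic.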
+     messages = tf.keras.layers.Dense(
+         8, kernel_initializer=tf.constant_initializer(1.))(
+             values)

    pooled = tfgnn.pool_edges_to_node(
        graph,
@@ -67,7 +123,9 @@ def build_model(self):
    h_old = graph.node_sets["node"].features[tfgnn.HIDDEN_STATE]

    h_next = tf.keras.layers.Concatenate()((pooled, h_old))
-     h_next = tf.keras.layers.Dense(8)(h_next)
+     h_next = tf.keras.layers.Dense(
+         4, kernel_initializer=tf.constant_initializer(1.))(
+             h_next)

    graph = graph.replace_features(
        node_sets={"node": {
@@ -87,30 +145,71 @@ def test_adapt(self):
        feature_name=tfgnn.HIDDEN_STATE)(model(gt))
    actual = adapted(gt)

-     self.assertAllClose(actual, expected)
+     self.assertAllClose(actual, expected, rtol=1e-04, atol=1e-04)

  def test_fit(self):
-     gt = tfgnn.random_graph_tensor(self.gtspec)
-     ds = tf.data.Dataset.from_tensors(gt).repeat(8)
-     ds = ds.batch(2).map(tfgnn.GraphTensor.merge_batch_to_components)
+     ds = tf.data.Dataset.from_tensors(self.get_graph_tensor()).repeat(8)
+     ds = ds.batch(self.global_batch_size).map(
+         tfgnn.GraphTensor.merge_batch_to_components)

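+     # Seed the global RNG so the adapted model and its evaluation are
+     # reproducible.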
+     tf.random.set_seed(self.seed)
    model = self.task.adapt(self.build_model())
    model.compile()

    def get_loss():
+       tf.random.set_seed(self.seed)
      values = model.evaluate(ds)
      return dict(zip(model.metrics_names, values))["loss"]

    before = get_loss()
    model.fit(ds)
    after = get_loss()
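+     # One epoch of training should reduce the loss from its initial value.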
+     self.assertAllClose(before, 21754138.0, rtol=1e-04, atol=1e-04)
+     self.assertAllClose(after, 16268301.0, rtol=1e-04, atol=1e-04)
+
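+   # Repeats the fit/evaluate cycle under distribution strategies; the
+   # multi-worker case relies on the multi-process test runner in main().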
+   @tfdistribute.combinations.generate(
+       tftest.combinations.combine(distribution=[
+           tfdistribute.combinations.mirrored_strategy_with_one_gpu,
+           tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
+       ]))
+   def test_distributed(self, distribution):
+     gt = self.get_graph_tensor()
+
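+     # Each replica derives its own batch size from the global batch size via
+     # the InputContext when the dataset is distributed.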
+     def dataset_fn(input_context=None, gt=gt):
+       ds = tf.data.Dataset.from_tensors(gt).repeat(8)
+       if input_context:
+         batch_size = input_context.get_per_replica_batch_size(
+             self.global_batch_size)
+       else:
+         batch_size = self.global_batch_size
+       ds = ds.batch(batch_size).map(tfgnn.GraphTensor.merge_batch_to_components)
+       return ds
+
+     with distribution.scope():
+       tf.random.set_seed(self.seed)
+       model = self.task.adapt(self.build_model())
+       model.compile()
+
+     def get_loss():
+       tf.random.set_seed(self.seed)
+       values = model.evaluate(
+           distribution.distribute_datasets_from_function(dataset_fn), steps=4)
+       return dict(zip(model.metrics_names, values))["loss"]
+
+     before = get_loss()
+     model.fit(
+         distribution.distribute_datasets_from_function(dataset_fn),
+         steps_per_epoch=4)
+     after = get_loss()
+     self.assertAllClose(before, 21754138.0, rtol=1e-04, atol=1e-04)
+     self.assertAllClose(after, 16268301.0, rtol=1e-04, atol=1e-04)

-     self.assertAllClose(before, 250.42036, rtol=1e-04, atol=1e-04)
-     self.assertAllClose(after, 13.18533, rtol=1e-04, atol=1e-04)
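+     # Export the trained model to check that the DGI-adapted model can still
+     # be saved after training under a distribution strategy.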
+     export_dir = os.path.join(self.get_temp_dir(), "dropout-model")
+     model.save(export_dir)

  def test_protocol(self):
    self.assertIsInstance(dgi.DeepGraphInfomax, orchestration.Task)


if __name__ == "__main__":
-   tf.test.main()
+   tfdistribute.multi_process_runner.test_main()