Skip to content

Commit 1b152b2

Browse files
Neslihanstensorflower-gardener
authored andcommitted
Update add_loss calls in DGI task to reduce them by dividing with the global_batch_size before passing to Keras.
PiperOrigin-RevId: 487897250
1 parent d7a9659 commit 1b152b2

File tree

3 files changed

+145
-26
lines changed

3 files changed

+145
-26
lines changed

Diff for: tensorflow_gnn/runner/tasks/BUILD

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "pytype_strict_library")
22
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "py_strict_test")
3+
load("@tensorflow_gnn//tensorflow_gnn:tensorflow_gnn.bzl", "distribute_py_test")
34

45
licenses(["notice"])
56

@@ -40,13 +41,16 @@ pytype_strict_library(
4041
],
4142
)
4243

43-
py_strict_test(
44+
distribute_py_test(
4445
name = "dgi_test",
4546
srcs = ["dgi_test.py"],
4647
srcs_version = "PY3",
48+
xla_enable_strict_auto_jit = False,
4749
deps = [
4850
":dgi",
51+
"//:expect_absl_installed",
4952
"//:expect_tensorflow_installed",
53+
"//:expect_tensorflow_installed:tensorflow_no_contrib",
5054
"//tensorflow_gnn",
5155
"//tensorflow_gnn/runner:orchestration",
5256
],

Diff for: tensorflow_gnn/runner/tasks/dgi.py

+29-13
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,17 @@
2323
class AddLossDeepGraphInfomax(tf.keras.layers.Layer):
2424
""""A bilinear layer with losses and metrics for Deep Graph Infomax."""
2525

26-
def __init__(self, units: int):
26+
def __init__(self, units: int, global_batch_size: int, **kwargs):
2727
"""Builds the bilinear layer weights.
2828
2929
Args:
3030
units: Units for the bilinear layer.
31+
global_batch_size: Global batch size to compute the average loss.
32+
**kwargs: Extra arguments needed for serialization.
3133
"""
32-
super().__init__()
34+
super().__init__(**kwargs)
3335
self._bilinear = tf.keras.layers.Dense(units, use_bias=False)
36+
self._global_batch_size = global_batch_size
3437

3538
def get_config(self) -> Mapping[Any, Any]:
3639
"""Returns the config of the layer.
@@ -58,13 +61,19 @@ def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
5861
y_clean, y_corrupted = inputs
5962
# Summary
6063
summary = tf.math.reduce_mean(y_clean, axis=0, keepdims=True)
64+
per_replica_batch_size = (
65+
self._global_batch_size //
66+
tf.distribute.get_strategy().num_replicas_in_sync)
6167
# Clean losses and metrics
6268
logits_clean = tf.matmul(y_clean, self._bilinear(summary), transpose_b=True)
63-
self.add_loss(tf.keras.losses.BinaryCrossentropy(
69+
loss_clean = tf.keras.losses.BinaryCrossentropy(
6470
from_logits=True,
65-
name="binary_crossentropy_clean")(
66-
tf.ones_like(logits_clean),
67-
logits_clean))
71+
name="binary_crossentropy_clean",
72+
reduction=tf.keras.losses.Reduction.NONE)
73+
self.add_loss(
74+
tf.nn.compute_average_loss(
75+
loss_clean(tf.ones_like(logits_clean), logits_clean),
76+
global_batch_size=per_replica_batch_size))
6877
self.add_metric(
6978
tf.keras.metrics.binary_crossentropy(
7079
tf.ones_like(logits_clean),
@@ -81,11 +90,14 @@ def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
8190
y_corrupted,
8291
self._bilinear(summary),
8392
transpose_b=True)
84-
self.add_loss(tf.keras.losses.BinaryCrossentropy(
93+
loss_corrupted = tf.keras.losses.BinaryCrossentropy(
8594
from_logits=True,
86-
name="binary_crossentropy_corrupted")(
87-
tf.zeros_like(logits_corrupted),
88-
logits_corrupted))
95+
name="binary_crossentropy_corrupted",
96+
reduction=tf.keras.losses.Reduction.NONE)
97+
self.add_loss(
98+
tf.nn.compute_average_loss(
99+
loss_corrupted(tf.zeros_like(logits_corrupted), logits_corrupted),
100+
global_batch_size=per_replica_batch_size))
89101
self.add_metric(
90102
tf.keras.metrics.binary_crossentropy(
91103
tf.zeros_like(logits_corrupted),
@@ -125,18 +137,21 @@ class DeepGraphInfomax:
125137
def __init__(self,
126138
node_set_name: str,
127139
*,
140+
global_batch_size: int,
128141
state_name: str = tfgnn.HIDDEN_STATE,
129142
seed: Optional[int] = None):
130143
"""Captures arguments for the task.
131144
132145
Args:
133146
node_set_name: The node set for activations.
147+
global_batch_size: Global batch size(not per-replica) for the training.
134148
state_name: The state name of any activations.
135149
seed: A seed for corrupted representations.
136150
"""
137151
self._state_name = state_name
138152
self._node_set_name = node_set_name
139153
self._seed = seed
154+
self._global_batch_size = global_batch_size
140155

141156
def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
142157
"""Adapt a `tf.keras.Model` for Deep Graph Infomax.
@@ -164,15 +179,16 @@ def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
164179
feature_name=self._state_name)(model.output)
165180

166181
# Corrupted representations: shuffling, model application and readout
167-
shuffled = tfgnn.shuffle_features_globally(model.input)
182+
shuffled = tfgnn.shuffle_features_globally(model.input, seed=self._seed)
168183
y_corrupted = tfgnn.keras.layers.ReadoutFirstNode(
169184
node_set_name=self._node_set_name,
170185
feature_name=self._state_name)(model(shuffled))
171186

172187
return tf.keras.Model(
173188
model.input,
174-
AddLossDeepGraphInfomax(
175-
y_clean.get_shape()[-1])((y_clean, y_corrupted)))
189+
AddLossDeepGraphInfomax(y_clean.get_shape()[-1],
190+
self._global_batch_size)(
191+
(y_clean, y_corrupted)))
176192

177193
def preprocess(self, gt: tfgnn.GraphTensor) -> tfgnn.GraphTensor:
178194
"""Returns the input GraphTensor."""

Diff for: tensorflow_gnn/runner/tasks/dgi_test.py

+111-12
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
"""Tests for dgi."""
16+
import os
17+
18+
from absl.testing import parameterized
1619
import tensorflow as tf
20+
import tensorflow.__internal__.distribute as tfdistribute
21+
import tensorflow.__internal__.test as tftest
1722
import tensorflow_gnn as tfgnn
1823

1924
from tensorflow_gnn.runner import orchestration
@@ -42,10 +47,59 @@
4247
""" % tfgnn.HIDDEN_STATE
4348

4449

45-
class DeepGraphInfomaxTest(tf.test.TestCase):
46-
50+
def _all_eager_distributed_strategy_combinations():
51+
strategies = [
52+
# MirroredStrategy
53+
tfdistribute.combinations.mirrored_strategy_with_gpu_and_cpu,
54+
tfdistribute.combinations.mirrored_strategy_with_one_cpu,
55+
tfdistribute.combinations.mirrored_strategy_with_one_gpu,
56+
""" # MultiWorkerMirroredStrategy
57+
tfdistribute.combinations.multi_worker_mirrored_2x1_cpu,
58+
tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
59+
# TPUStrategy
60+
tfdistribute.combinations.tpu_strategy,
61+
tfdistribute.combinations.tpu_strategy_one_core,
62+
tfdistribute.combinations.tpu_strategy_packed_var,
63+
# ParameterServerStrategy
64+
tfdistribute.combinations.parameter_server_strategy_3worker_2ps_cpu,
65+
tfdistribute.combinations.parameter_server_strategy_3worker_2ps_1gpu,
66+
tfdistribute.combinations.parameter_server_strategy_1worker_2ps_cpu,
67+
tfdistribute.combinations.parameter_server_strategy_1worker_2ps_1gpu, """
68+
]
69+
return tftest.combinations.combine(distribution=strategies)
70+
71+
72+
class DeepGraphInfomaxTest(tf.test.TestCase, parameterized.TestCase):
73+
74+
global_batch_size = 2
4775
gtspec = tfgnn.create_graph_spec_from_schema_pb(tfgnn.parse_schema(SCHEMA))
48-
task = dgi.DeepGraphInfomax("node", seed=8191)
76+
seed = 8191
77+
task = dgi.DeepGraphInfomax(
78+
"node", global_batch_size=global_batch_size, seed=seed)
79+
80+
def get_graph_tensor(self):
81+
gt = tfgnn.GraphTensor.from_pieces(
82+
node_sets={
83+
"node":
84+
tfgnn.NodeSet.from_fields(
85+
features={
86+
tfgnn.HIDDEN_STATE:
87+
tf.convert_to_tensor([[1., 2., 3., 4.],
88+
[11., 11., 11., 11.],
89+
[19., 19., 19., 19.]])
90+
},
91+
sizes=tf.convert_to_tensor([3])),
92+
},
93+
edge_sets={
94+
"edge":
95+
tfgnn.EdgeSet.from_fields(
96+
sizes=tf.convert_to_tensor([2]),
97+
adjacency=tfgnn.Adjacency.from_indices(
98+
("node", tf.convert_to_tensor([0, 1], dtype=tf.int32)),
99+
("node", tf.convert_to_tensor([2, 0], dtype=tf.int32)),
100+
)),
101+
})
102+
return gt
49103

50104
def build_model(self):
51105
graph = inputs = tf.keras.layers.Input(type_spec=self.gtspec)
@@ -56,7 +110,9 @@ def build_model(self):
56110
"edge",
57111
tfgnn.TARGET,
58112
feature_name=tfgnn.HIDDEN_STATE)
59-
messages = tf.keras.layers.Dense(16)(values)
113+
messages = tf.keras.layers.Dense(
114+
8, kernel_initializer=tf.constant_initializer(1.))(
115+
values)
60116

61117
pooled = tfgnn.pool_edges_to_node(
62118
graph,
@@ -67,7 +123,9 @@ def build_model(self):
67123
h_old = graph.node_sets["node"].features[tfgnn.HIDDEN_STATE]
68124

69125
h_next = tf.keras.layers.Concatenate()((pooled, h_old))
70-
h_next = tf.keras.layers.Dense(8)(h_next)
126+
h_next = tf.keras.layers.Dense(
127+
4, kernel_initializer=tf.constant_initializer(1.))(
128+
h_next)
71129

72130
graph = graph.replace_features(
73131
node_sets={"node": {
@@ -87,30 +145,71 @@ def test_adapt(self):
87145
feature_name=tfgnn.HIDDEN_STATE)(model(gt))
88146
actual = adapted(gt)
89147

90-
self.assertAllClose(actual, expected)
148+
self.assertAllClose(actual, expected, rtol=1e-04, atol=1e-04)
91149

92150
def test_fit(self):
93-
gt = tfgnn.random_graph_tensor(self.gtspec)
94-
ds = tf.data.Dataset.from_tensors(gt).repeat(8)
95-
ds = ds.batch(2).map(tfgnn.GraphTensor.merge_batch_to_components)
151+
ds = tf.data.Dataset.from_tensors(self.get_graph_tensor()).repeat(8)
152+
ds = ds.batch(self.global_batch_size).map(
153+
tfgnn.GraphTensor.merge_batch_to_components)
96154

155+
tf.random.set_seed(self.seed)
97156
model = self.task.adapt(self.build_model())
98157
model.compile()
99158

100159
def get_loss():
160+
tf.random.set_seed(self.seed)
101161
values = model.evaluate(ds)
102162
return dict(zip(model.metrics_names, values))["loss"]
103163

104164
before = get_loss()
105165
model.fit(ds)
106166
after = get_loss()
167+
self.assertAllClose(before, 21754138.0, rtol=1e-04, atol=1e-04)
168+
self.assertAllClose(after, 16268301.0, rtol=1e-04, atol=1e-04)
169+
170+
@tfdistribute.combinations.generate(
171+
tftest.combinations.combine(distribution=[
172+
tfdistribute.combinations.mirrored_strategy_with_one_gpu,
173+
tfdistribute.combinations.multi_worker_mirrored_2x1_gpu,
174+
]))
175+
def test_distributed(self, distribution):
176+
gt = self.get_graph_tensor()
177+
178+
def dataset_fn(input_context=None, gt=gt):
179+
ds = tf.data.Dataset.from_tensors(gt).repeat(8)
180+
if input_context:
181+
batch_size = input_context.get_per_replica_batch_size(
182+
self.global_batch_size)
183+
else:
184+
batch_size = self.global_batch_size
185+
ds = ds.batch(batch_size).map(tfgnn.GraphTensor.merge_batch_to_components)
186+
return ds
187+
188+
with distribution.scope():
189+
tf.random.set_seed(self.seed)
190+
model = self.task.adapt(self.build_model())
191+
model.compile()
192+
193+
def get_loss():
194+
tf.random.set_seed(self.seed)
195+
values = model.evaluate(
196+
distribution.distribute_datasets_from_function(dataset_fn), steps=4)
197+
return dict(zip(model.metrics_names, values))["loss"]
198+
199+
before = get_loss()
200+
model.fit(
201+
distribution.distribute_datasets_from_function(dataset_fn),
202+
steps_per_epoch=4)
203+
after = get_loss()
204+
self.assertAllClose(before, 21754138.0, rtol=1e-04, atol=1e-04)
205+
self.assertAllClose(after, 16268301.0, rtol=1e-04, atol=1e-04)
107206

108-
self.assertAllClose(before, 250.42036, rtol=1e-04, atol=1e-04)
109-
self.assertAllClose(after, 13.18533, rtol=1e-04, atol=1e-04)
207+
export_dir = os.path.join(self.get_temp_dir(), "dropout-model")
208+
model.save(export_dir)
110209

111210
def test_protocol(self):
112211
self.assertIsInstance(dgi.DeepGraphInfomax, orchestration.Task)
113212

114213

115214
if __name__ == "__main__":
116-
tf.test.main()
215+
tfdistribute.multi_process_runner.test_main()

0 commit comments

Comments
 (0)