Description
Describe the problem.
Reloading optimizer variables when partitioning is present fails to actually load them if the load is done first (deferred restoration; a minimal sketch of what that means is below). The code runs, but the accumulator state is the same as if nothing had been loaded at all. This is a more specific follow-up to this one.
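For context, this is roughly what the deferred path looks like on its own, without any partitioner (a minimal sketch, not the repro itself; it assumes weights were previously saved at the path the script below writes to):

import tensorflow as tf

# Sketch of deferred restoration: load_weights() before build() queues the
# restore, and Keras applies it later, once build() creates the variables.
model = tf.keras.Sequential([tf.keras.layers.Embedding(100, 32)])
status = model.load_weights("/tmp/model_weights")  # variables don't exist yet; restore is deferred
model.build(input_shape=(None, 1))                 # deferred restore is applied here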
The code linked here, with three adjustments, can be used to reproduce this. That code does:
toy_model2.build(input_shape=(None, 1))
toy_model2.optimizer.build(toy_model2.trainable_variables) # type: ignore
print("Loading weights...")
toy_model2.load_weights(weights_path).assert_consumed()
Instead, change it to do:
toy_model2.load_weights(weights_path).assert_consumed()
toy_model2.build(input_shape=(None, 1))
toy_model2.optimizer.build(toy_model2.trainable_variables) # type: ignore
print("Loading weights...")
Also, keep the partitioner in the second model, so change
strategy2 = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=resolver)
to
strategy2 = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=resolver, variable_partitioner=partitioner)
Lastly, we need to train the first model for a couple of steps. Something like:
toy_model.fit(tf.constant([[1], [2]]), tf.constant([[0.0], [1.0]]), epochs=1, steps_per_epoch=1)
If the weights are loaded before building, they silently fail to restore for partitioned optimizers. When no partitioner is used, everything works fine.
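One rough way to confirm the silent failure numerically, rather than relying on assert_consumed(), is to compare the slot values directly (a sketch that assumes the toy_model / toy_model2 objects from the repro script below; with partitioning the exact variable structure may differ, so treat this as illustrative only):

import numpy as np

saved = [np.array(v) for v in toy_model.optimizer.variables]      # trained, then saved
restored = [np.array(v) for v in toy_model2.optimizer.variables]  # supposedly restored
print("optimizer slots match:",
      all(s.shape == r.shape and np.allclose(s, r) for s, r in zip(saved, restored)))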
Lastly, PS is not strictly necessary for this issue; PS is just a common way of using a partitioner. If you use a tf.variable_creator_scope that produces sharded variables, you can trigger the same issue without PS.
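To make that non-PS path concrete, here is a sketch of where such a creator would hook in via tf.variable_creator_scope. The actual sharding logic (splitting the initial value and returning a sharded container, as the PS strategy's partitioner does) is omitted; this creator simply delegates, so it only shows the interception point:

import tensorflow as tf

def sharding_creator(next_creator, **kwargs):
    # A real reproduction would split kwargs["initial_value"] along axis 0
    # and return a sharded variable wrapper here; we simply delegate.
    return next_creator(**kwargs)

with tf.variable_creator_scope(sharding_creator):
    model = tf.keras.Sequential(
        [tf.keras.layers.Embedding(100, 32), tf.keras.layers.Dense(1, activation="sigmoid")]
    )
    model.build(input_shape=(None, 1))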
Standalone code to reproduce the issue.
from typing import Any, Mapping
import json
import os
import tempfile
from multiprocessing import Process
import tensorflow as tf
from portpicker import pick_unused_port
__spec__ = None

def create_tf_configs(worker_count: int, ps_count: int):
    """Create TF_CONFIGs for a cluster."""
    cluster_dict: dict[str, list[str]] = {}
    if worker_count:
        cluster_dict["worker"] = [f"localhost:{pick_unused_port()}" for _ in range(worker_count)]
    if ps_count:
        cluster_dict["ps"] = [f"localhost:{pick_unused_port()}" for _ in range(ps_count)]
    cluster_dict["chief"] = [f"localhost:{pick_unused_port()}"]
    tf_configs = []
    for i in range(worker_count):
        tf_configs.append({"cluster": cluster_dict, "task": {"type": "worker", "index": i}})
    for i in range(ps_count):
        tf_configs.append({"cluster": cluster_dict, "task": {"type": "ps", "index": i}})
    tf_configs.append({"cluster": cluster_dict, "task": {"type": "chief", "index": 0}})
    return tf_configs

def _create_process(tf_config: Mapping[str, Any]):
    name = tf_config["task"]["type"] + "_" + str(tf_config["task"]["index"])
    print(f"Starting {name} process...")
    os.environ["TF_CONFIG"] = json.dumps(tf_config)
    p = Process(target=run)
    p.start()

def run():
    resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
    task_type = resolver.task_type
    if task_type in ("worker", "ps"):
        # Workers and parameter servers just run a server and block.
        print("Starting server...")
        server = tf.distribute.Server(
            resolver.cluster_spec(),
            job_name=resolver.task_type,
            task_index=resolver.task_id,
            protocol=resolver.rpc_layer,
            start=True,
        )
        server.join()

    partitioner = tf.distribute.experimental.partitioners.MaxSizePartitioner(
        max_shard_bytes=100 * 16 * 4
    )
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver=resolver, variable_partitioner=partitioner
    )
    print("Building model...")
    with strategy.scope():
        toy_model = tf.keras.Sequential(
            [tf.keras.layers.Embedding(100, 32), tf.keras.layers.Dense(1, activation="sigmoid")]
        )
        toy_model.compile(loss="binary_crossentropy", optimizer=tf.optimizers.experimental.Adam())
        toy_model.build(input_shape=(None, 1))
        toy_model.optimizer.build(toy_model.trainable_variables)  # type: ignore
    toy_model.fit(tf.constant([[1], [2]]), tf.constant([[0.0], [1.0]]), epochs=1, steps_per_epoch=1)
    print(toy_model.optimizer.variables)

    print("Saving weights...")
    temp_dir = tempfile.gettempdir()
    weights_path = os.path.join(temp_dir, "model_weights")
    toy_model.save_weights(weights_path)

    strategy2 = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver=resolver, variable_partitioner=partitioner
    )
    with strategy2.scope():
        toy_model2 = tf.keras.Sequential(
            [tf.keras.layers.Embedding(100, 32), tf.keras.layers.Dense(1, activation="sigmoid")]
        )
        toy_model2.compile(loss="binary_crossentropy", optimizer=tf.optimizers.experimental.Adam())
        # Key repro step: load the weights *before* building, so restoration is deferred.
        toy_model2.load_weights(weights_path)
        toy_model2.build(input_shape=(None, 1))
        toy_model2.optimizer.build(toy_model2.trainable_variables)  # type: ignore
    print(toy_model2.optimizer.variables)  # Inconsistent with printed optimizer variables prior.
    print("Loading weights...")
    print("Done!")

def main():
    tf_configs = create_tf_configs(2, 1)
    chief_config = tf_configs[-1]
    for tf_config in tf_configs[:-1]:
        _create_process(tf_config)
    os.environ["TF_CONFIG"] = json.dumps(chief_config)
    run()


if __name__ == "__main__":
    main()