Skip to content

Partitioned Optimizer Variable Reloading issue #838

Open
@hmc-cs-mdrissi

Description

@hmc-cs-mdrissi

Describe the problem.

Reloading optimizer variables when partitioning is present fails to actually load the variables if the load is done first (deferred restoration). The code runs, but the accumulator state is the same as if nothing had been loaded at all. This is a more specific follow-up to this one.

The code linked here with 3 adjustments can be used to reproduce this. That code does,

toy_model2.build(input_shape=(None, 1))
toy_model2.optimizer.build(toy_model2.trainable_variables)  # type: ignore
print("Loading weights...")
toy_model2.load_weights(weights_path).assert_consumed()

Instead change it to do,

toy_model2.load_weights(weights_path).assert_consumed()
toy_model2.build(input_shape=(None, 1))
toy_model2.optimizer.build(toy_model2.trainable_variables)  # type: ignore
print("Loading weights...")

Also keep partitioner in second model so change,

strategy2 = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=resolver)

to

strategy2 = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=resolver, variable_partitioner=partitioner)

Lastly, we need to train the first model for a couple of steps, with something like:

toy_model.fit(tf.constant([[1], [2]]), tf.constant([[0.0], [1.0]]), epochs=1, steps_per_epoch=1)

If weights are loaded before building, they silently fail to restore for partitioned optimizers. When no partitioner is used, it works fine.

Lastly, PS is not necessary for this issue. PS is just a common way of using a partitioner; if you use a tf.variable_creator_scope that produces sharded variables, you can cause the same issue without PS.

Standalone code to reproduce the issue.

from typing import Any, Mapping

import json
import os
import tempfile
from multiprocessing import Process

import tensorflow as tf

from portpicker import pick_unused_port

# NOTE(review): presumably a workaround so that multiprocessing can start child
# processes when this script is run from an interactive session or debugger
# where ``__main__.__spec__`` is not set — confirm before removing.
__spec__ = None


def create_tf_configs(worker_count: int, ps_count: int):
    """Build one TF_CONFIG-style dict per task of a localhost cluster.

    Every task gets its own ``localhost:<port>`` address, with the port taken
    from ``pick_unused_port()``. The cluster always contains exactly one
    ``chief`` task; ``worker`` and ``ps`` jobs are only listed when their count
    is non-zero.

    Args:
        worker_count: Number of ``worker`` tasks.
        ps_count: Number of ``ps`` (parameter server) tasks.

    Returns:
        A list of dicts, each with a ``"cluster"`` entry (shared by all tasks)
        and a ``"task"`` entry. Workers come first, then parameter servers,
        and the chief is always the last element.
    """
    jobs = (("worker", worker_count), ("ps", ps_count))

    # Addresses are allocated in job order (workers, ps, chief).
    cluster: dict[str, list[str]] = {}
    for job, count in jobs:
        if count:
            cluster[job] = [f"localhost:{pick_unused_port()}" for _ in range(count)]
    cluster["chief"] = [f"localhost:{pick_unused_port()}"]

    # One config per task; all of them reference the same cluster dict.
    configs = [
        {"cluster": cluster, "task": {"type": job, "index": index}}
        for job, count in jobs
        for index in range(count)
    ]
    configs.append({"cluster": cluster, "task": {"type": "chief", "index": 0}})
    return configs


def _create_process(tf_config: Mapping[str, Any]):
    """Start a child process that executes ``run`` for the given task.

    Args:
        tf_config: A TF_CONFIG-style mapping; its ``"task"`` entry must
            provide ``"type"`` and ``"index"``.
    """
    task = tf_config["task"]
    label = "_".join((task["type"], str(task["index"])))
    print(f"Starting {label} process...")

    # TF_CONFIG is written to this process's environment before the child is
    # started; ``run`` picks the task up through TFConfigClusterResolver.
    os.environ["TF_CONFIG"] = json.dumps(tf_config)
    child = Process(target=run)
    child.start()


def run():
    """Run one task of the repro cluster, selected by the ``TF_CONFIG`` env var.

    ``worker`` and ``ps`` tasks start a ``tf.distribute.Server`` and then wait
    in ``server.join()``. The remaining task (the chief) does the actual repro:

    1. Builds a small partitioned model, trains it for one step and prints the
       optimizer variables.
    2. Saves the weights to ``<tempdir>/model_weights``.
    3. Creates a second, identically partitioned model, loads the checkpoint
       *before* the model and its optimizer are built, builds both, and prints
       the optimizer variables again for comparison with step 1.
    """
    resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()

    task_type = resolver.task_type
    if task_type in ("worker", "ps"):
        print("Starting server...")
        server = tf.distribute.Server(
            resolver.cluster_spec(),
            job_name=resolver.task_type,
            task_index=resolver.task_id,
            protocol=resolver.rpc_layer,
            start=True,
        )
        server.join()

    # 6400-byte shard limit. NOTE(review): assuming float32 weights, the
    # 100x32 embedding table is 12800 bytes, so it should end up sharded.
    partitioner = tf.distribute.experimental.partitioners.MaxSizePartitioner(max_shard_bytes=100 * 16 * 4)
    strategy = tf.distribute.experimental.ParameterServerStrategy(
        cluster_resolver=resolver, variable_partitioner=partitioner
    )

    print("Building model...")
    with strategy.scope():
        toy_model = tf.keras.Sequential(
            [tf.keras.layers.Embedding(100, 32), tf.keras.layers.Dense(1, activation="sigmoid")]
        )
        toy_model.compile(loss="binary_crossentropy", optimizer=tf.optimizers.experimental.Adam())
        toy_model.build(input_shape=(None, 1))
        toy_model.optimizer.build(toy_model.trainable_variables)  # type: ignore
        # Train for one step so the optimizer state differs from its initial
        # values. ``Model.fit`` takes ``epochs``, not ``num_epochs``.
        toy_model.fit(tf.constant([[1], [2]]), tf.constant([[0.0], [1.0]]), epochs=1, steps_per_epoch=1)
        print(toy_model.optimizer.variables)

    print("Saving weights...")
    temp_dir = tempfile.gettempdir()
    weights_path = os.path.join(temp_dir, "model_weights")
    toy_model.save_weights(weights_path)

    strategy2 = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver=resolver, variable_partitioner=partitioner)
    with strategy2.scope():
        toy_model2 = tf.keras.Sequential(
            [tf.keras.layers.Embedding(100, 32), tf.keras.layers.Dense(1, activation="sigmoid")]
        )
        toy_model2.compile(loss="binary_crossentropy", optimizer=tf.optimizers.experimental.Adam())
        # Load first, build afterwards: this ordering is the point of the
        # repro, since restoration of the not-yet-created variables is deferred.
        print("Loading weights...")
        toy_model2.load_weights(weights_path)
        toy_model2.build(input_shape=(None, 1))
        toy_model2.optimizer.build(toy_model2.trainable_variables)  # type: ignore
        print(toy_model2.optimizer.variables)  # Inconsistent with printed optimizer variables prior.

    print("Done!")


def main():
    """Launch a 2-worker, 1-ps local cluster and run the chief in-process.

    Every non-chief task is started in its own child process; the chief
    config (the last one returned by ``create_tf_configs``) is then exported
    as ``TF_CONFIG`` and ``run`` is called directly in this process.
    """
    *background_configs, chief_config = create_tf_configs(2, 1)

    for config in background_configs:
        _create_process(config)

    # Switch this process over to the chief task only after all children
    # have been started with their own TF_CONFIG.
    os.environ["TF_CONFIG"] = json.dumps(chief_config)
    run()


# Start the repro cluster only when this file is executed as a script.
if __name__ == "__main__":
    main()

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions