
Strange behavior of saving sharded trainstate in GCP. #660

Open
@chiamp

A user posted in the Flax discussions about an orbax discrepancy between different zones in GCE. Do different zones have different orbax versions?

==================================================================

What happened

When I save my sharded state with orbax in asia-northeast3-a in GCE, orbax creates a /tmp/orbax_ckpt/0/_sharding file that starts with

{"dropout_rng":"{\"sharding_type\": \"NamedSharding\", \"shape\": [2, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}","opt_state.0.0.count":"{\"sharding_type\": \"NamedSharding\", \"shape\": [2, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}",
...

My sharded state has a "dropout_rng" entry, so the file above makes sense.

However, when I run the same script in another zone such as asia-southeast1-b, orbax creates the _sharding file without readable layer names, for example,

{"ZHJvcG91dF9ybmc=":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}","b3B0X3N0YXRlLjAuMC5jb3VudA==":"{\"sharding_type\": \"NamedSharding\", \"shape\": [1, 1], \"axis_names\": [\"data\", \"model\"], \"partition_spec\": []}",
...
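The unreadable keys in the second file look like Base64 encodings of the same names that appear in the first file. The sketch below decodes the two keys quoted above using only the Python standard library. The helper name `maybe_b64_decode` is hypothetical, and why one environment writes encoded keys while the other writes plain ones is not established here; that part remains an open question.

```python
import base64
import binascii
import json


def maybe_b64_decode(key: str) -> str:
    """Return the Base64-decoded key, or the key unchanged if it does not decode cleanly.

    This is a heuristic: a plain key that happens to be valid Base64 and valid
    UTF-8 after decoding would be decoded too.
    """
    try:
        return base64.b64decode(key, validate=True).decode("utf-8")
    except (binascii.Error, UnicodeDecodeError):
        return key


# Keys copied from the second _sharding file quoted above (values shortened).
encoded = {
    "ZHJvcG91dF9ybmc=": "...",
    "b3B0X3N0YXRlLjAuMC5jb3VudA==": "...",
}

decoded = {maybe_b64_decode(k): v for k, v in encoded.items()}
print(json.dumps(decoded))
# prints {"dropout_rng": "...", "opt_state.0.0.count": "..."}
```

Plain keys such as "dropout_rng" pass through unchanged, because "_" and "." are not in the standard Base64 alphabet, so the same helper can be applied to either file.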

Theory

I suspect that this is related to OCDBT, because the only difference between the terminal outputs is that OCDBT is initialized in asia-northeast3-a, while the other zones do not print this message:
type_handlers.py:223] OCDBT is initialized successfully..

I confirmed that tensorstore==0.1.51 is installed in all zones.
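Since the question is whether the zones run different orbax versions, a check along these lines could be run on each VM and the outputs compared. It is only a sketch: the function name `installed_versions` is hypothetical, and the distribution names in the list (in particular `orbax-checkpoint`) are assumptions about how the packages were installed.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Assumed distribution names; adjust them to match the environment.
    names = ["orbax-checkpoint", "tensorstore", "jax", "flax"]
    for name, version in installed_versions(names).items():
        print(f"{name}=={version}")
```

If the printed versions match across zones, the difference more likely comes from the runtime environment than from the installed packages.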

Can anyone help me, please?

Thank you.

Originally posted by @sw32-seo in google/flax#3538
