Skip to content

optimizer state_dict misses some values #3597

Open
@JacoCheung

Description

Hi FBGEMM team, while using a TorchRec `EmbeddingCollection` with the Adam optimizer, I found that `ec.fused_optimizer.state_dict()` returns nothing but the momentum tensors. Hyperparameters such as `lr` and weight decay, which a regular `torch.optim.Adam` exposes in its state dict's `param_groups`, are missing.

See thread

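I think the problem is here [the link to the relevant source location was not preserved in this copy]. Minimal reproduction script: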


# Minimal reproduction: TorchRec's fused optimizer state_dict() for a sharded
# EmbeddingCollection exposes per-parameter optimizer tensors, but the
# hyperparameters (lr, betas, ...) that torch.optim.Adam keeps in
# `param_groups` do not appear to be available.
#
# Run directly as a single process, or under `torchrun --nproc_per_node=1`.
# Set TORCHREC_SRC to import TorchRec from a local source checkout instead of
# the installed package.
import os
import sys

# The original script hard-coded a developer-specific checkout path here;
# make it opt-in so the repro runs against an installed TorchRec by default.
_torchrec_src = os.environ.get("TORCHREC_SRC")
if _torchrec_src:
    sys.path.append(os.path.abspath(_torchrec_src))

import torch
import torch.distributed as dist
import torchrec
from torchrec.distributed.embedding import EmbeddingCollectionSharder
from torchrec.distributed.fbgemm_qcomm_codec import (  # noqa: F401 (used by the optional hint below)
    CommType,
    QCommsConfig,
    get_qcomm_codecs_registry,
)
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

# Single-process defaults. setdefault keeps any values already provided by a
# launcher such as torchrun instead of silently overwriting them.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="nccl")  # NCCL backend: requires a CUDA device

# Two 4096 x 64 tables, built on the meta device so no memory is allocated
# until DistributedModelParallel materializes the shards.
ec = torchrec.EmbeddingCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        ),
        torchrec.EmbeddingConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        ),
    ],
)

# Attach Adam (lr=0.02) to the embedding parameters; these settings are what
# the issue expects to find again in the fused optimizer's state_dict().
apply_optimizer_in_backward(
    optimizer_class=torch.optim.Adam,
    params=ec.parameters(),
    optimizer_kwargs={"lr": 0.02},
)

# Quantized comms can optionally be enabled by also passing, e.g.:
#   qcomm_codecs_registry=get_qcomm_codecs_registry(
#       qcomms_config=QCommsConfig(
#           forward_precision=CommType.FP16,
#           backward_precision=CommType.BF16,
#       )
#   ),
sharder = EmbeddingCollectionSharder(use_index_dedup=True)
model = torchrec.distributed.DistributedModelParallel(
    ec, sharders=[sharder], device=torch.device("cuda")
)

# Keys "product" and "user" with per-sample lengths; the 8 values are split
# across the 6 length entries (2+0+1+1+1+3 == 8).
mb = torchrec.KeyedJaggedTensor(
    keys=["product", "user"],
    values=torch.tensor([101, 201, 101, 404, 404, 606, 606, 606]).cuda(),
    lengths=torch.tensor([2, 0, 1, 1, 1, 3], dtype=torch.int64).cuda(),
)

ret = model(mb)  # per the original repro, the forward returns an awaitable
product = ret["product"]  # indexing implicitly waits on the awaitable

# Show what the fused optimizer actually exposes. `.get` is used so the
# repro reports missing entries instead of crashing on them.
opt_state = model.fused_optimizer.state_dict()
if dist.get_rank() == 0:
    print("top-level keys:", list(opt_state.keys()))
    product_state = opt_state.get("state", {}).get("embeddings.product_table.weight", {})
    print("product_table state keys:", list(product_state.keys()))
    print("param_groups:", opt_state.get("param_groups"))

dist.destroy_process_group()

Running the script above, the fused optimizer state looks like this:

>>> model.fused_optimizer.state_dict()['state']['embeddings.product_table.weight'].keys()
dict_keys(['product_table.momentum1', 'product_table.exp_avg_sq']) 
# Only the per-parameter `state` entries are present; there are no `param_groups` holding `lr`, `betas`, etc.

This metadata is required when I need to save (checkpoint) the model together with its optimizer and reload it later.

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions