Hi fbgemm team, while using a torchrec EmbeddingCollection with the Adam optimizer, I found that ec.fused_optimizer.state_dict() returns nothing but momentum tensors. lr, weight decay, etc., which are normally accessible from a plain torch.optim.Adam, are gone. See thread. I think the problem is here.
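For contrast, here is what a plain torch.optim.Adam reports, as a minimal sketch on a toy parameter (independent of torchrec; the param_groups fields come from stock PyTorch):

import torch

p = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.Adam([p], lr=0.02)
p.sum().backward()  # create a gradient
opt.step()          # materialize the optimizer state

sd = opt.state_dict()
print(sd.keys())              # dict_keys(['state', 'param_groups'])
print(sd["param_groups"][0])  # includes 'lr', 'betas', 'eps', 'weight_decay'

My repro: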
import os
import sys

sys.path.append(os.path.abspath('/home/scratch.junzhang_sw/workspace/torchrec'))

import torch
import torch.distributed as dist
import torchrec

# Single-process process group so DistributedModelParallel can be used.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
dist.init_process_group(backend="nccl")
# Two embedding tables, created on the meta device and sharded later.
ebc = torchrec.EmbeddingCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
        ),
        torchrec.EmbeddingConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
        ),
    ],
)
from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward

# Fuse Adam into the backward pass; the fused optimizer is later exposed
# as model.fused_optimizer.
apply_optimizer_in_backward(
    optimizer_class=torch.optim.Adam,
    params=ebc.parameters(),
    optimizer_kwargs={"lr": 0.02},
)
from torchrec.distributed.fbgemm_qcomm_codec import get_qcomm_codecs_registry, QCommsConfig, CommType
from torchrec.distributed.embedding import EmbeddingCollectionSharder

sharder = EmbeddingCollectionSharder(
    use_index_dedup=True,
    # qcomm_codecs_registry=get_qcomm_codecs_registry(
    #     qcomms_config=QCommsConfig(
    #         forward_precision=CommType.FP16,
    #         backward_precision=CommType.BF16,
    #     )
    # ),
)
dp_rank = dist.get_rank()
# Sharding materializes the meta-device tables on the GPU.
model = torchrec.distributed.DistributedModelParallel(
    ebc, sharders=[sharder], device=torch.device("cuda")
)
# One batch: two keys, three samples each; lengths sum to len(values).
mb = torchrec.KeyedJaggedTensor(
    keys=["product", "user"],
    values=torch.tensor([101, 201, 101, 404, 404, 606, 606, 606]).cuda(),
    lengths=torch.tensor([2, 0, 1, 1, 1, 3], dtype=torch.int64).cuda(),
)
ret = model(mb)  # returns an Awaitable
product = ret["product"]  # implicitly calls Awaitable.wait()
The above model gives me an optimizer state like this:
>>> model.fused_optimizer.state_dict()['state']['embeddings.product_table.weight'].keys()
dict_keys(['product_table.momentum1', 'product_table.exp_avg_sq'])
Only the `state` entry is populated; there are no param_groups containing lr, beta1, etc. This metadata is required when I need to dump and reload the model.
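A stopgap I am considering: persist the hyperparameters myself alongside the fused state. This is only a sketch under my own conventions (the checkpoint layout and the optimizer_kwargs dict are mine, not a torchrec API):

# Hypothetical workaround: save the kwargs originally passed to
# apply_optimizer_in_backward next to the fused optimizer state,
# since state_dict() alone does not carry them.
optimizer_kwargs = {"lr": 0.02}  # must match the apply_optimizer_in_backward call

checkpoint = {
    "model": model.state_dict(),
    "fused_optimizer": model.fused_optimizer.state_dict(),
    "optimizer_kwargs": optimizer_kwargs,  # restore by hand on reload
}
torch.save(checkpoint, "checkpoint.pt")

It would still be much better if the fused optimizer's state_dict() carried param_groups directly.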