Support sending lengths to TBE instead of just offsets #2557

Open · wants to merge 2 commits into main
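For context on the change: a KeyedJaggedTensor's offsets are the zero-prefixed running sum of its lengths, and this PR lets the quantized TBE accept lengths directly and perform that conversion inside the lookup. A minimal sketch of the relationship, assuming fbgemm_gpu is installed so that torch.ops.fbgemm.asynchronous_complete_cumsum (the op the PR itself calls) is registered; the tensor values are purely illustrative:

import torch
import fbgemm_gpu  # noqa: F401  (registers torch.ops.fbgemm.*)

lengths = torch.tensor([2, 0, 3], dtype=torch.int32)  # three bags with 2, 0, and 3 ids
offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
# offsets is tensor([0, 2, 2, 5]): a leading zero followed by the running sum of lengths.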
19 changes: 17 additions & 2 deletions torchrec/distributed/embedding_sharding.py
@@ -446,6 +446,15 @@ def _prefetch_and_cached(
)


def _all_tables_are_quant_kernel(
tables: List[ShardedEmbeddingTable],
) -> bool:
"""
Return True if all tables use the quant compute kernel.
"""
return all(table.compute_kernel == EmbeddingComputeKernel.QUANT for table in tables)


# group tables by `DataType`, `PoolingType`, and `EmbeddingComputeKernel`.
def group_tables(
tables_per_rank: List[List[ShardedEmbeddingTable]],
@@ -489,6 +498,8 @@ def _group_tables_per_rank(
# Collect groups
groups = defaultdict(list)
grouping_keys = []
# Assumes all tables on this rank use the same compute kernel
is_inference = _all_tables_are_quant_kernel(embedding_tables)
for table in embedding_tables:
bucketer = (
prefetch_cached_dim_bucketer
@@ -499,12 +510,16 @@
_get_grouping_fused_params(table.fused_params, table.name) or {}
)
grouping_key = (
table.data_type,
table.data_type if not is_inference else None,
table.pooling,
table.has_feature_processor,
tuple(sorted(group_fused_params.items())),
_get_compute_kernel_type(table.compute_kernel),
bucketer.get_bucket(table.local_cols, table.data_type),
# TODO: Unit test to check if table.data_type affects table grouping
bucketer.get_bucket(
table.local_cols,
table.data_type,
),
_prefetch_and_cached(table),
)
# micromanage the order in which we traverse the groups to ensure backwards compatibility
20 changes: 19 additions & 1 deletion torchrec/distributed/fused_params.py
@@ -7,7 +7,7 @@

# pyre-strict

from typing import Any, Dict, Iterable, Optional
from typing import Any, Dict, Iterable, List, Optional

import torch

Expand All @@ -24,6 +24,10 @@
FUSED_PARAM_TBE_ROW_ALIGNMENT: str = "__register_tbe_row_alignment"
FUSED_PARAM_BOUNDS_CHECK_MODE: str = "__register_tbe_bounds_check_mode"

# Force the lengths-to-offsets conversion to happen right before the TBE lookup
# (the lookup receives lengths). Helps performance with certain ways of splitting models.
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: str = "__register_lengths_to_offsets_lookup"


class TBEToRegisterMixIn:
def get_tbes_to_register(
@@ -68,6 +72,18 @@ def fused_param_bounds_check_mode(
return fused_params[FUSED_PARAM_BOUNDS_CHECK_MODE]


def fused_param_lengths_to_offsets_lookup(
fused_params: Optional[Dict[str, Any]]
) -> bool:
if (
fused_params is None
or FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP not in fused_params
):
return False
else:
return fused_params[FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP]


def is_fused_param_quant_state_dict_split_scale_bias(
fused_params: Optional[Dict[str, Any]]
) -> bool:
@@ -93,5 +109,7 @@ def tbe_fused_params(
fused_params_for_tbe.pop(FUSED_PARAM_TBE_ROW_ALIGNMENT)
if FUSED_PARAM_BOUNDS_CHECK_MODE in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE)
if FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP)

return fused_params_for_tbe
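
Usage sketch (not part of the diff): enabling the new behavior is just another fused-params entry, which fused_param_lengths_to_offsets_lookup reads and tbe_fused_params strips before the dict reaches the TBE constructor. This assumes the additions to torchrec.distributed.fused_params from this PR are available:

from torchrec.distributed.fused_params import (
    FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
    fused_param_lengths_to_offsets_lookup,
    tbe_fused_params,
)

# Opt in by adding the flag to the fused params handed to a sharder.
fused_params = {FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True}
assert fused_param_lengths_to_offsets_lookup(fused_params)
# The flag is consumed by TorchRec and removed before reaching the TBE kwargs.
assert FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP not in tbe_fused_params(fused_params)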
196 changes: 131 additions & 65 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -33,6 +33,7 @@
)
from torchrec.distributed.fused_params import (
fused_param_bounds_check_mode,
fused_param_lengths_to_offsets_lookup,
is_fused_param_quant_state_dict_split_scale_bias,
is_fused_param_register_tbe,
tbe_fused_params,
@@ -171,6 +172,19 @@ def _unwrap_kjt_for_cpu(
return indices, offsets, None


@torch.fx.wrap
def _unwrap_kjt_lengths(
features: KeyedJaggedTensor,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
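# Unwrap the KJT into (values, lengths, optional per-sample weights) without materializing offsets.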
indices = features.values()
lengths = features.lengths()
return (
indices.int(),
lengths.int(),
features.weights_or_none(),
)


@torch.fx.wrap
def _unwrap_optional_tensor(
tensor: Optional[torch.Tensor],
@@ -180,6 +194,26 @@ def _unwrap_optional_tensor(
return tensor


class IntNBitTableBatchedEmbeddingBagsCodegenWithLength(
IntNBitTableBatchedEmbeddingBagsCodegen
):
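"""
Thin wrapper around IntNBitTableBatchedEmbeddingBagsCodegen whose forward takes
per-bag lengths instead of offsets; lengths are converted to offsets with
torch.ops.fbgemm.asynchronous_complete_cumsum before calling the base forward.
"""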
def __init__(self, *args: Any, **kwargs: Dict[str, Any]) -> None:
super().__init__(*args, **kwargs)

# pyre-ignore Inconsistent override [14]
def forward(
self,
indices: torch.Tensor,
lengths: torch.Tensor,
per_sample_weights: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return super().forward(
indices,
torch.ops.fbgemm.asynchronous_complete_cumsum(lengths),
per_sample_weights,
)


class QuantBatchedEmbeddingBag(
BaseBatchedEmbeddingBag[
Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
@@ -192,6 +226,7 @@ def __init__(
pg: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
fused_params: Optional[Dict[str, Any]] = None,
data_type_changed: bool = False,
) -> None:
super().__init__(config, pg, device)

@@ -216,40 +251,53 @@ def __init__(
self._runtime_device: torch.device = _get_runtime_device(device, config)
# 16 for CUDA, 1 for others like CPU and MTIA.
self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (
IntNBitTableBatchedEmbeddingBagsCodegen(
embedding_specs=[
embedding_specs = []
for local_rows, local_cols, table, location in zip(
self._local_rows,
self._local_cols,
config.embedding_tables,
managed,
):
embedding_specs.append(
(
table.name,
local_rows,
(
table.name,
local_rows,
(
local_cols
if self._quant_state_dict_split_scale_bias
else table.embedding_dim
),
data_type_to_sparse_type(config.data_type),
location,
)
for local_rows, local_cols, table, location in zip(
self._local_rows,
self._local_cols,
config.embedding_tables,
managed,
)
],
device=device,
pooling_mode=self._pooling,
feature_table_map=self._feature_table_map,
row_alignment=self._tbe_row_alignment,
uvm_host_mapped=True, # Use cudaHostAlloc for UVM CACHING to fix imbalance numa memory issue
bounds_check_mode=(
bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
),
feature_names_per_table=[
table.feature_names for table in config.embedding_tables
],
**(tbe_fused_params(fused_params) or {}),
local_cols
if self._quant_state_dict_split_scale_bias
else table.embedding_dim
),
data_type_to_sparse_type(
# If data_type has changed, default to the up-to-date config.data_type instead of the embedding_tables' data type, which does not yet reflect the quantized type
config.data_type
if data_type_changed
else table.data_type
),
location,
)
)

self.lengths_to_tbe: bool = fused_param_lengths_to_offsets_lookup(fused_params)

if self.lengths_to_tbe:
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegenWithLength
else:
tbe_clazz = IntNBitTableBatchedEmbeddingBagsCodegen

self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = tbe_clazz(
embedding_specs=embedding_specs,
device=device,
pooling_mode=self._pooling,
feature_table_map=self._feature_table_map,
row_alignment=self._tbe_row_alignment,
uvm_host_mapped=True,  # Use cudaHostAlloc for UVM CACHING to fix the imbalanced NUMA memory issue
bounds_check_mode=(
bounds_check_mode if bounds_check_mode else BoundsCheckMode.WARNING
),
feature_names_per_table=[
table.feature_names for table in config.embedding_tables
],
**(tbe_fused_params(fused_params) or {}),
)
if device is not None:
self._emb_module.initialize_weights()
@@ -268,44 +316,50 @@ def get_tbes_to_register(
) -> Dict[IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig]:
return {self._emb_module: self._config}

def _emb_module_forward(
self,
indices: torch.Tensor,
lengths_or_offsets: torch.Tensor,
weights: Optional[torch.Tensor],
) -> torch.Tensor:
kwargs = {"indices": indices}

if self._is_weighted:
kwargs["per_sample_weights"] = _unwrap_optional_tensor(weights)

if self.lengths_to_tbe:
kwargs["lengths"] = lengths_or_offsets
else:
kwargs["offsets"] = lengths_or_offsets

if self._emb_module_registered:
# Conditional call of .forward function for FX:
# emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
# emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
# For some post-processing that requires the TBE emb_module to be copied into fx.GraphModule we need call_module, as it copies this module into fx.GraphModule unchanged.
return self._emb_module(**kwargs)
else:
return self._emb_module.forward(**kwargs)

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
# Important: _unwrap_kjt regex for FX tracing tagging
lengths, offsets = None, None
if self._runtime_device.type == "cpu":
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
features, self._config.is_weighted
)
else:
indices, offsets, per_sample_weights = _unwrap_kjt(features)

if self._is_weighted:
weights = _unwrap_optional_tensor(per_sample_weights)
if self._emb_module_registered:
# Conditional call of .forward function for FX:
# emb_module() can go through FX only if emb_module is registered in named_modules (FX node call_module)
# emb_module.forward() does not require registering emb_module in named_modules (FX node call_function)
# For some post processing that requires TBE emb_module copied in fx.GraphModule we need to be call_module, as it will copies this module inside fx.GraphModule unchanged.
return self.emb_module(
indices=indices,
offsets=offsets,
per_sample_weights=weights,
)
if self.lengths_to_tbe:
indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features)
else:
return self.emb_module.forward(
indices=indices,
offsets=offsets,
per_sample_weights=weights,
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
features, self._config.is_weighted
)
else:
if self._emb_module_registered:
return self.emb_module(
indices=indices,
offsets=offsets,
)
if self.lengths_to_tbe:
indices, lengths, per_sample_weights = _unwrap_kjt_lengths(features)
else:
return self.emb_module.forward(
indices=indices,
offsets=offsets,
)
indices, offsets, per_sample_weights = _unwrap_kjt(features)

return self._emb_module_forward(
indices, lengths if lengths is not None else offsets, per_sample_weights
)

def named_buffers(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
@@ -359,8 +413,16 @@ def from_float(
)
device = next(iter(state_dict.values())).device

data_type_changed = False
# The qconfig data type can differ from the config data type if we are quantizing already-sharded embeddings.
# This means the embedding_tables within GroupedEmbeddingConfig do not have the up-to-date data type, as they have not yet been quantized.
if data_type != module.config.data_type:
data_type_changed = True
# We update the config to have the right data_type, sparse_type, and device. This update does not change the embedding_tables data type
config = _copy_config(module.config, data_type, sparse_type, device)
ret = QuantBatchedEmbeddingBag(config=config, device=device)
ret = QuantBatchedEmbeddingBag(
config=config, device=device, data_type_changed=data_type_changed
)

# pyre-ignore
quant_weight_list = _quantize_weight(state_dict, data_type)
@@ -411,7 +473,11 @@ def __init__(
if self._quant_state_dict_split_scale_bias
else table.embedding_dim
),
data_type_to_sparse_type(config.data_type),
(
data_type_to_sparse_type(config.data_type)
if config.data_type is not None
else data_type_to_sparse_type(table.data_type)
),
location,
)
for local_rows, local_cols, table, location in zip(
15 changes: 15 additions & 0 deletions torchrec/distributed/tests/test_embedding_sharding.py
@@ -429,6 +429,21 @@ def test_should_not_group_together(
)
return

# If both kernels are quantized, we assume this is inference, where we no longer split by data_type
# So if other attributes are the same between the two tables (regardless of data type), we combine them
if (
tables[0].compute_kernel == EmbeddingComputeKernel.QUANT
and tables[1].compute_kernel == EmbeddingComputeKernel.QUANT
and tables[0].pooling == tables[1].pooling
and tables[0].has_feature_processor == tables[1].has_feature_processor
):

self.assertEqual(
sorted(_get_table_names_by_groups(tables)),
[["table_0", "table_1"]],
)
return

self.assertEqual(
sorted(_get_table_names_by_groups(tables)),
[["table_0"], ["table_1"]],
2 changes: 2 additions & 0 deletions torchrec/inference/modules.py
@@ -26,6 +26,7 @@
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.fused_params import (
FUSED_PARAM_BOUNDS_CHECK_MODE,
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
FUSED_PARAM_REGISTER_TBE_BOOL,
)
@@ -82,6 +83,7 @@ def trim_torch_package_prefix_from_typename(typename: str) -> str:
FUSED_PARAM_REGISTER_TBE_BOOL: True,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS: True,
FUSED_PARAM_BOUNDS_CHECK_MODE: BoundsCheckMode.NONE,
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: False,
}

DEFAULT_SHARDERS: List[ModuleSharder[torch.nn.Module]] = [
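Closing usage sketch: to opt in at inference time one can override the default shown in the torchrec/inference/modules.py hunk above. This assumes the defaults dict is exported as DEFAULT_FUSED_PARAMS (the name is not visible in this hunk) and that sharders accept the resulting fused-params dict:

from torchrec.distributed.fused_params import FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP
from torchrec.inference.modules import DEFAULT_FUSED_PARAMS  # name assumed, see note above

# Keep the inference defaults but enable the lengths-based TBE lookup.
custom_fused_params = {**DEFAULT_FUSED_PARAMS, FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: True}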