Skip to content

Commit a7a0eb7

Browse files
q10 authored and facebook-github-bot committed
Report TBE data configuration with EEG-based indices estimation (flesh out D71147675, pt 2)
Summary: Add the option to report TBE data configuration with EEG-based indices estimation.
Differential Revision: D71519199
1 parent 3be1d8d commit a7a0eb7

File tree

4 files changed

+162
-4
lines changed

4 files changed

+162
-4
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/bench/tbe_data_config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class TBEDataConfig:
4343
D: int
4444
# Generate mixed dimensions if true
4545
mixed_dim: bool
46-
# Whether the table is weighted or not
46+
# Whether the lookup indices are weighted or not
4747
weighted: bool
4848
# Batch parameters
4949
batch_params: BatchParams
@@ -102,7 +102,7 @@ def variable_L(self) -> bool:
102102
return self.pooling_params.sigma_L is not None
103103

104104
def _new_weights(self, size: int) -> Optional[torch.Tensor]:
    """
    Create random per-sample weights for a weighted lookup.

    Args:
        size (int): Number of weight values to generate

    Returns:
        A tensor of `size` random values on the device returned by
        `get_device()` when `self.weighted` is set, otherwise `None`.
    """
    if not self.weighted:
        return None

    # Per-sample weights will always be FP32
    return torch.randn(size, device=get_device())
107107

108108
def _generate_batch_sizes(self) -> Tuple[List[int], Optional[List[List[int]]]]:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
from .bench_params_reporter import TBEBenchmarkParamsReporter # noqa F401
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import io
11+
import logging
12+
import os
13+
from typing import List, Optional
14+
15+
import fbgemm_gpu # noqa F401
16+
import numpy as np
17+
import torch
18+
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
19+
DenseTableBatchedEmbeddingBagsCodegen, # noqa
20+
SplitTableBatchedEmbeddingBagsCodegen,
21+
)
22+
from fbgemm_gpu.tbe.bench import (
23+
BatchParams,
24+
IndicesParams,
25+
PoolingParams,
26+
TBEDataConfig,
27+
)
28+
29+
# pyre-ignore[16]
30+
open_source: bool = getattr(fbgemm_gpu, "open_source", False)
31+
32+
if open_source:
33+
from fbgemm_gpu.utils import FileStore
34+
else:
35+
from fbgemm_gpu.fb.utils import FileStore
36+
37+
38+
class TBEBenchmarkParamsReporter:
    """
    Estimates a `TBEDataConfig` from the inputs of a TBE lookup and writes it
    as JSON to a `FileStore`.
    """

    def __init__(
        self,
        report_interval: int,
        report_once: bool = False,
        bucket: Optional[str] = None,
        path_prefix: Optional[str] = None,
    ) -> None:
        """
        Args:
            report_interval (int): A report is written only on iterations
                that are a multiple of this value
            report_once (bool): If set, stop reporting after the first report
                has been written
            bucket (Optional[str]): Storage location handed to `FileStore`.
                Falls back to the `FBGEMM_TBE_REPORTING_BUCKET` environment
                variable, then to a build-dependent default
            path_prefix (Optional[str]): Accepted for interface
                compatibility; currently not applied to the output path
        """
        self.report_interval = report_interval
        self.report_once = report_once
        # Tracks whether a report has been written, to honor report_once
        self.has_reported: bool = False

        default_bucket = "/tmp" if open_source else "tlparse_reports"
        bucket = bucket or os.environ.get(
            "FBGEMM_TBE_REPORTING_BUCKET", default_bucket
        )
        self.filestore = FileStore(bucket)

        self.logger: logging.Logger = logging.getLogger(__name__)
        self.logger.setLevel(logging.INFO)

    def _should_report(self, iteration: int) -> bool:
        """Return True if a report is due at the given iteration."""
        if iteration % self.report_interval != 0:
            return False
        return not (self.report_once and self.has_reported)

    @staticmethod
    def _is_mixed(values: List[int]) -> bool:
        """Return True if `values` contains more than one distinct value."""
        return len(np.unique(values)) > 1

    @staticmethod
    def _ceil_mean(values: List[int]) -> int:
        """Return the mean of `values`, rounded up to a Python int."""
        return int(np.ceil(np.mean(values)))

    @staticmethod
    def _ceil_std(values: List[int]) -> int:
        """Return the standard deviation of `values`, rounded up to a Python int."""
        return int(np.ceil(np.std(values)))

    def report_stats(
        self,
        embedding_op: "SplitTableBatchedEmbeddingBagsCodegen",
        indices: torch.Tensor,
        offsets: torch.Tensor,
        per_sample_weights: Optional[torch.Tensor] = None,
        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
    ) -> None:
        """
        Estimate the TBE data configuration from the given inputs and write it
        to the file store as JSON.  Nothing is written unless the op's current
        iteration is a multiple of `report_interval` (and, with `report_once`,
        no report has been written yet).

        Args:
            embedding_op (SplitTableBatchedEmbeddingBagsCodegen): The TBE op
                whose `iter`, `uuid`, `embedding_specs` and
                `feature_table_map` are read
            indices (Tensor): Input indices
            offsets (Tensor): Input offsets
            per_sample_weights (Optional[Tensor]): Input per sample weights;
                only its presence is recorded
            batch_size_per_feature_per_rank (Optional[List[List[int]]]):
                Variable batch sizes, if any
        """
        iteration = embedding_op.iter.item()
        if not self._should_report(iteration):
            return

        # Transfer indices back to CPU for EEG analysis
        indices_cpu = indices.cpu()

        # Extract embedding table specs
        embedding_specs = [
            embedding_op.embedding_specs[t] for t in embedding_op.feature_table_map
        ]
        rowcounts = [embedding_spec[0] for embedding_spec in embedding_specs]
        dims = [embedding_spec[1] for embedding_spec in embedding_specs]

        # Set T to be the number of features we are looking at
        T = len(embedding_op.feature_table_map)
        # Set E to the common row count, or to the (rounded-up) mean of the
        # row counts when the tables differ
        E = self._ceil_mean(rowcounts) if self._is_mixed(rowcounts) else rowcounts[0]
        # Set mixed_dim to be True if there are multiple dims
        mixed_dim = self._is_mixed(dims)
        # Set D to the common dim, or to the (rounded-up) mean of the dims
        # when the tables differ
        D = self._ceil_mean(dims) if mixed_dim else dims[0]

        # Compute indices distribution parameters
        heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution(
            indices_cpu
        )
        indices_params = IndicesParams(
            heavy_hitters, q, s, indices.dtype, offsets.dtype
        )

        # Compute batch parameters
        if batch_size_per_feature_per_rank:
            sigma_B = self._ceil_std(
                [b for bs in batch_size_per_feature_per_rank for b in bs]
            )
            vbe_distribution = "normal"
            # NOTE(review): this is the length of the outer list; confirm
            # that the outer dimension is ranks rather than features
            vbe_num_ranks = len(batch_size_per_feature_per_rank)
        else:
            sigma_B = None
            vbe_distribution = None
            vbe_num_ranks = None
        batch_params = BatchParams(
            B=((offsets.numel() - 1) // T),
            sigma_B=sigma_B,
            vbe_distribution=vbe_distribution,
            vbe_num_ranks=vbe_num_ranks,
        )

        # Compute pooling parameters.  The bag sizes are moved to the host as
        # plain ints so that distinct-value counting and the numpy statistics
        # operate on values rather than on tensor objects.
        bag_sizes = (offsets[1:] - offsets[:-1]).cpu().tolist()
        mixed_bag_sizes = self._is_mixed(bag_sizes)
        pooling_params = PoolingParams(
            L=self._ceil_mean(bag_sizes) if mixed_bag_sizes else bag_sizes[0],
            sigma_L=(self._ceil_std(bag_sizes) if mixed_bag_sizes else None),
            length_distribution=("normal" if mixed_bag_sizes else None),
        )

        config = TBEDataConfig(
            T=T,
            E=E,
            D=D,
            mixed_dim=mixed_dim,
            weighted=(per_sample_weights is not None),
            batch_params=batch_params,
            indices_params=indices_params,
            pooling_params=pooling_params,
            use_cpu=(not torch.cuda.is_available()),
        )

        # Write the TBE config to FileStore
        self.filestore.write(
            f"tbe-{embedding_op.uuid}-config-estimation-{iteration}.json",
            io.BytesIO(config.json(format=True).encode()),
        )
        # Mark only after a successful write so a failed write can be retried
        self.has_reported = True

fbgemm_gpu/fbgemm_gpu/utils/filestore.py

-2
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,6 @@ class FileStore:
3636
bucket: str
3737

3838
def __post_init__(self) -> None:
39-
# self.bucket = bucket
40-
4139
if not os.path.isdir(self.bucket):
4240
raise ValueError(f"Directory {self.bucket} does not exist")
4341

0 commit comments

Comments
 (0)