@@ -57,25 +57,14 @@ def __init__(
         self.logger: logging.Logger = logging.getLogger(__name__)
         self.logger.setLevel(logging.INFO)
 
-    def report_stats(
+    def extract_params(
         self,
         embedding_op: SplitTableBatchedEmbeddingBagsCodegen,
         indices: torch.Tensor,
         offsets: torch.Tensor,
         per_sample_weights: Optional[torch.Tensor] = None,
         batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
-    ) -> None:
-        """
-        Print input stats (for debugging purpose only)
-
-        Args:
-            indices (Tensor): Input indices
-            offsets (Tensor): Input offsets
-            per_sample_weights (Optional[Tensor]): Input per
-                sample weights
-        """
-        if embedding_op.iter.item() % self.report_interval == 0:
-            pass
+    ) -> TBEDataConfig:
 
         # Transfer indices back to CPU for EEG analysis
         indices_cpu = indices.cpu()
@@ -89,12 +78,12 @@ def report_stats(
 
         # Set T to be the number of features we are looking at
         T = len(embedding_op.feature_table_map)
-        # Set E to be the median of the rowcounts to avoid biasing the
+        # Set E to be the mean of the rowcounts to avoid biasing
        E = rowcounts[0] if len(set(rowcounts)) == 1 else np.ceil((np.mean(rowcounts)))
         # Set mixed_dim to be True if there are multiple dims
         mixed_dim = len(set(dims)) > 1
-        # Set D to be the median of the dims to avoid biasing
-        D = dims[0] if mixed_dim else np.ceil((np.mean(dims)))
+        # Set D to be the mean of the dims to avoid biasing
+        D = dims[0] if not mixed_dim else np.ceil((np.mean(dims)))
 
         # Compute indices distribution parameters
         heavy_hitters, q, s, _, _ = torch.ops.fbgemm.tbe_estimate_indices_distribution(
@@ -123,15 +112,15 @@ def report_stats(
         )
 
         # Compute pooling parameters
-        bag_sizes = offsets[1:] - offsets[:-1]
+        bag_sizes = (offsets[1:] - offsets[:-1]).tolist()
         mixed_bag_sizes = len(set(bag_sizes)) > 1
         pooling_params = PoolingParams(
             L=np.ceil(np.mean(bag_sizes)) if mixed_bag_sizes else bag_sizes[0],
             sigma_L=(np.ceil(np.std(bag_sizes)) if mixed_bag_sizes else None),
             length_distribution=("normal" if mixed_bag_sizes else None),
         )
 
-        config = TBEDataConfig(
+        return TBEDataConfig(
             T=T,
             E=E,
             D=D,
@@ -143,8 +132,31 @@ def report_stats(
             use_cpu=(not torch.cuda.is_available()),
         )
 
-        # Write the TBE config to FileStore
-        self.filestore.write(
-            f"tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter.item()}.json",
-            io.BytesIO(config.json(format=True).encode()),
-        )
+    def report_stats(
+        self,
+        embedding_op: SplitTableBatchedEmbeddingBagsCodegen,
+        indices: torch.Tensor,
+        offsets: torch.Tensor,
+        per_sample_weights: Optional[torch.Tensor] = None,
+        batch_size_per_feature_per_rank: Optional[List[List[int]]] = None,
+    ) -> None:
+        """
+        Print input stats (for debugging purpose only)
+
+        Args:
+            indices (Tensor): Input indices
+            offsets (Tensor): Input offsets
+            per_sample_weights (Optional[Tensor]): Input per
+                sample weights
+        """
+        if embedding_op.iter.item() % self.report_interval == 0:
+            # Extract TBE config
+            config = self.extract_params(
+                embedding_op, indices, offsets, per_sample_weights
+            )
+
+            # Write the TBE config to FileStore
+            self.filestore.write(
+                f"tbe-{embedding_op.uuid}-config-estimation-{embedding_op.iter.item()}.json",
+                io.BytesIO(config.json(format=True).encode()),
+            )
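A minimal sketch (not part of the PR) of how the pooling parameters in extract_params are derived from the TBE offsets; the offsets values below are made-up example inputs, and the PoolingParams construction itself is elided.

import numpy as np
import torch

# Hypothetical offsets for four pooled bags (example values only)
offsets = torch.tensor([0, 2, 2, 5, 9])

# Per-bag lengths, computed as in extract_params; .tolist() yields plain ints
bag_sizes = (offsets[1:] - offsets[:-1]).tolist()  # [2, 0, 3, 4]
mixed_bag_sizes = len(set(bag_sizes)) > 1          # True

# Pooling parameters: ceil of mean/std when bag lengths vary, else the constant length
L = np.ceil(np.mean(bag_sizes)) if mixed_bag_sizes else bag_sizes[0]  # 3.0
sigma_L = np.ceil(np.std(bag_sizes)) if mixed_bag_sizes else None     # 2.0
length_distribution = "normal" if mixed_bag_sizes else None

The .tolist() added in extract_params presumably matters here: iterating a tensor yields 0-d tensors that hash by identity, so set() would not deduplicate equal bag sizes, whereas a plain Python list deduplicates by value and keeps bag_sizes[0] an int.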