
Commit 57c932d

q10 authored and facebook-github-bot committed
Update the rowwise adagrad optimizer to leverage optimizer state offloading, v4, backend (pytorch#4195)
Summary:
X-link: facebookresearch/FBGEMM#1271

Update the rowwise adagrad optimizer to leverage optimizer state offloading, v4. This is a revision of D74827718 that makes the flag an SSD-specific flag rather than an optimizer-specific one; scoping it to SSD expresses clear intent about the flag's use.

This diff adds support for leveraging optimizer state offloading when making optimizer state updates, starting with the rowwise adagrad optimizer:

- Add the SSD-specific flag `enable_optimizer_offloading` to the table update kernel to enable handling optimizer offloading, starting with the rowwise adagrad case
- Propagate the flag upwards to `torch.ops.fbgemm.{{ mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2`

Differential Revision: D75329024
1 parent: 0064436 · commit: 57c932d

9 files changed (+131, −6 lines)
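The substance of the change is in the generated rowwise adagrad update: when `enable_optimizer_offloading` is set, the running sum of squared gradients is read from and written back to an `OptimizerState` view over the cache row itself rather than the separate `momentum1` tensor. The standalone C++ sketch below mirrors that branch; the types here (a plain `OptimizerState` struct, `momentum1` as a `std::vector`) are simplified stand-ins for the FBGEMM device types, not the actual implementation.

#include <cmath>
#include <cstdio>
#include <vector>

// Stand-in for the generated OptimizerState view (one momentum per row)
struct OptimizerState {
  float momentum;  // running sum of squared gradients
};

float rowwise_adagrad_multiplier(
    float g_avg_square,               // mean squared gradient for this row
    float learning_rate,
    float eps,
    bool enable_optimizer_offloading,
    OptimizerState* offloaded_state,  // state embedded in the cache row (SSD path)
    std::vector<float>& momentum1,    // separate optimizer state tensor (default path)
    int idx) {
  float new_sum_square_grads = g_avg_square;
  if (enable_optimizer_offloading) {
    // SSD path: read and update the state stored alongside the weights
    new_sum_square_grads += offloaded_state->momentum;
    offloaded_state->momentum = new_sum_square_grads;
  } else {
    // Default path: read and update the dense momentum1 tensor
    new_sum_square_grads += momentum1[idx];
    momentum1[idx] = new_sum_square_grads;
  }
  return learning_rate / (std::sqrt(new_sum_square_grads) + eps);
}

int main() {
  std::vector<float> momentum1(4, 0.0f);
  OptimizerState row_state{0.0f};
  // Same gradient statistics, two different storage locations for the state
  const float m_tensor = rowwise_adagrad_multiplier(
      0.25f, 0.1f, 1e-8f, /*enable_optimizer_offloading=*/false, &row_state, momentum1, 0);
  const float m_offload = rowwise_adagrad_multiplier(
      0.25f, 0.1f, 1e-8f, /*enable_optimizer_offloading=*/true, &row_state, momentum1, 0);
  std::printf("multiplier: tensor path = %f, offloaded path = %f\n", m_tensor, m_offload);
  return 0;
}

Both paths produce the same multiplier; only where the state lives differs, which is the point of the flag.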

fbgemm_gpu/codegen/genscript/optimizers.py

Lines changed: 22 additions & 4 deletions
@@ -186,15 +186,33 @@ def rowwise_adagrad() -> Dict[str, Any]:
             g_local_sum_square += gx * gx + gy * gy + gz * gz + gw * gw;
         """
     )
-    split_precomputation += """
+    split_precomputation += """
+    // Define the rowwise adagrad optimizer state struct view
+    struct [[maybe_unused]] OptimizerState {
+        at::acc_type<cache_t, true> momentum;
+    };
+
     const at::acc_type<cache_t, true> g_avg_square =
         GROUP_REDUCE_ALL_SUM(g_local_sum_square, at::acc_type<cache_t, true>) / D;

     at::acc_type<cache_t, true> multiplier = 0.0;
     at::acc_type<cache_t, true> correction = 0.0;
-    if (threadIdx.x == 0) {
-        at::acc_type<cache_t, true> new_sum_square_grads = momentum1[idx] + g_avg_square;
-        momentum1[idx] = new_sum_square_grads;
+    if (threadIdx.x == 0) {
+        auto new_sum_square_grads = g_avg_square;
+
+        // Update the optimizer state. Use optimizer state offloading only if
+        // SSD and if enabled by the user
+        if (enable_optimizer_offloading) {
+            // Fetch the pointer to the optimizer state along the cache row
+            auto* optimizer = weight_row_template.template optimizer_state_ptr<OptimizerState>();
+            new_sum_square_grads += optimizer->momentum;
+            optimizer->momentum = new_sum_square_grads;
+
+        } else {
+            new_sum_square_grads += momentum1[idx];
+            momentum1[idx] = new_sum_square_grads;
+        }
+
         multiplier = learning_rate / (sqrtf(new_sum_square_grads) + eps);
         if (weight_decay_mode == 1) {
             // L2 regularization
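The generated kernel reaches the offloaded state through `weight_row_template.template optimizer_state_ptr<OptimizerState>()`. Below is a rough sketch of that struct-view idea, under the assumption that the optimizer state occupies the bytes immediately after the embedding weights in the cache row; `CacheRowView` and the exact layout are illustrative assumptions, not the FBGEMM definitions.

#include <cstdio>
#include <vector>

struct OptimizerState {
  float momentum;
};

// Assumed row layout: [dim weight elements][optimizer state]
struct CacheRowView {
  float* data;
  int dim;  // number of weight elements in the row

  OptimizerState* optimizer_state_ptr() {
    // Reinterpret the tail of the row as a typed optimizer-state struct
    return reinterpret_cast<OptimizerState*>(data + dim);
  }
};

int main() {
  const int dim = 8;
  std::vector<float> row(dim + sizeof(OptimizerState) / sizeof(float), 0.0f);
  CacheRowView view{row.data(), dim};
  view.optimizer_state_ptr()->momentum = 3.5f;
  std::printf("momentum stored in row tail: %f\n", view.optimizer_state_ptr()->momentum);
  return 0;
}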

fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_template.cpp

Lines changed: 28 additions & 0 deletions
@@ -187,6 +187,9 @@ enum SSDTensor {
     use_uniq_cache_locations_bwd,
     use_homogeneous_placements,
     {%- endif %}
+    {%- if ssd %}
+    enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_gwd %}
     {%- if "prev_iter_dev" not in args.split_function_arg_names %}
     prev_iter_dev,

@@ -350,6 +353,9 @@ enum SSDTensor {
     is_experimental,
     use_uniq_cache_locations_bwd,
     use_homogeneous_placements,
+    {%- if ssd %}
+    enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_gwd %}
     {%- if "prev_iter_dev" not in args.split_function_arg_names %}
     prev_iter_dev,

@@ -520,6 +526,9 @@ Tensor
     {%- if not dense %}
     const bool use_uniq_cache_locations,
     const bool use_homogeneous_placements,
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {%- endif %}
     {%- if is_gwd %}
     {%- if "prev_iter_dev" not in args.split_function_arg_names %}

@@ -609,6 +618,9 @@ class {{ autograd_func }} :
     const bool is_experimental,
     const bool use_uniq_cache_locations_bwd,
     const bool use_homogeneous_placements,
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_gwd %}
     {%- if "prev_iter_dev" not in args.split_function_arg_names %}
     const std::optional<Tensor>& prev_iter_dev,

@@ -783,6 +795,11 @@ class {{ autograd_func }} :
     ctx->saved_data["use_uniq_cache_locations_bwd"] = use_uniq_cache_locations_bwd;
     ctx->saved_data["use_homogeneous_placements"] = use_homogeneous_placements;
     {%- endif %}
+
+    {%- if ssd %}
+    ctx->saved_data["enable_optimizer_offloading"] = enable_optimizer_offloading;
+    {%- endif %}
+
     {%- if is_gwd %}
     {%- if "iter" not in args.split_function_arg_names %}
     ctx->saved_data["iter"] = iter;

@@ -900,6 +917,11 @@ class {{ autograd_func }} :
     const auto use_homogeneous_placements =
         ctx->saved_data["use_homogeneous_placements"].toBool();
     {%- endif %}
+
+    {%- if ssd %}
+    const auto enable_optimizer_offloading =
+        ctx->saved_data["enable_optimizer_offloading"].toBool();
+    {%- endif %}

     {%- if is_gwd %}
     {%- if "iter" not in args.split_function_arg_names %}

@@ -1065,6 +1087,9 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function(
     const bool is_experimental_tbe = false, // formerly named is_experimental
     const bool use_uniq_cache_locations_bwd = false,
     const bool use_homogeneous_placements = false,
+    {%- if ssd %}
+    const bool enable_optimizer_offloading = false,
+    {%- endif %}
     const std::optional<Tensor>& uvm_cache_stats = std::nullopt,
     {%- if "prev_iter_dev" not in args.split_function_arg_names %}
     const std::optional<Tensor>& prev_iter_dev = std::nullopt,

@@ -1185,6 +1210,9 @@ TORCH_LIBRARY_FRAGMENT({{ lib_name }}, m) {
     " bool is_experimental=False, "
     " bool use_uniq_cache_locations_bwd=False, "
     " bool use_homogeneous_placements=False, "
+    {%- if ssd %}
+    " bool enable_optimizer_offloading=False, "
+    {%- endif %}
     " Tensor? uvm_cache_stats=None, "
     {%- if "prev_iter_dev" not in args.split_function_arg_names %}
     " Tensor? prev_iter_dev=None, "

fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_cta_template.cu

Lines changed: 9 additions & 0 deletions
@@ -155,6 +155,9 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
     {%- endif %}
     const float gwd_lower_bound,
     {%- endif %}
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_index_select %}
     const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
     const bool permute_output_dim_0_1

@@ -386,6 +389,9 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
     {%- endif %}
     shfl_sync_mask,
     max_vecs,
+    {%- if ssd %}
+    enable_optimizer_offloading,
+    {%- endif %}
     {{ args.split_kernel_arg_names | join(", ") }}
 );
 {%- else %}

@@ -523,6 +529,9 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row
     {%- endif %}
     const float gwd_lower_bound,
     {%- endif %}
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_index_select %}
     const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
     const bool permute_output_dim_0_1

fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu

Lines changed: 9 additions & 0 deletions
@@ -133,6 +133,9 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
     {%- endif %}
     const float gwd_lower_bound,
     {%- endif %}
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_index_select %}
     const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
     const bool permute_output_dim_0_1

@@ -296,6 +299,9 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
     {%- endif %}
     shfl_sync_mask,
     max_vecs,
+    {%- if ssd %}
+    enable_optimizer_offloading,
+    {%- endif %}
     {{ args.split_kernel_arg_names | join(", ") }}
 );
 {%- else %}

@@ -426,6 +432,9 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
     {%- endif %}
     const float gwd_lower_bound,
     {%- endif %}
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_index_select %}
     const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
     const bool permute_output_dim_0_1

fbgemm_gpu/codegen/training/backward/embedding_backward_split_meta_template.cpp

Lines changed: 3 additions & 0 deletions
@@ -110,6 +110,9 @@ Tensor {{ mdesc }}_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ desc
     const bool use_uniq_cache_locations,
     const bool use_homogeneous_placements,
     {%- endif %}
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_index_select %}
     const Tensor& grad_offsets,
     const Tensor& total_L_offsets,

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Lines changed: 18 additions & 0 deletions
@@ -130,6 +130,9 @@ batch_index_select_dim0_codegen_backward_kernel_cta_per_row(
     {%- endif %}
     const float gwd_lower_bound,
     {%- endif %}
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_index_select %}
     const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
     const bool permute_output_dim_0_1

@@ -213,6 +216,9 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
     {%- endif %}
     const float gwd_lower_bound,
     {%- endif %}
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_index_select %}
     const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> grad_offsets,
     const bool permute_output_dim_0_1

@@ -572,6 +578,9 @@ Tensor {{ embedding_cuda_op }}(
     const bool use_uniq_cache_locations,
     const bool use_homogeneous_placements,
     {%- endif %}
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_index_select %}
     const Tensor& grad_offsets,
     const Tensor& total_L_offsets,

@@ -1132,6 +1141,9 @@ Tensor {{ embedding_cuda_op }}(
     {%- endif %}
     gwd_lower_bound,
     {%- endif %}
+    {%- if ssd %}
+    enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_index_select %}
     grad_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
     permute_output_dim_0_1

@@ -1288,6 +1300,9 @@ Tensor {{ embedding_cuda_op }}(
     {%- endif %}
     gwd_lower_bound,
     {%- endif %}
+    {%- if ssd %}
+    enable_optimizer_offloading,
+    {%- endif %}
     {%- if is_index_select %}
     grad_offsets.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
     permute_output_dim_0_1

@@ -1380,6 +1395,9 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     " bool use_uniq_cache_locations, "
     " bool use_homogeneous_placements, "
     {%- endif %}
+    {%- if ssd %}
+    " bool enable_optimizer_offloading, "
+    {%- endif %}
     {%- if is_gwd_kernel %}
     {%- if "prev_iter_dev" not in args.split_function_arg_names %}
     " Tensor prev_iter_dev, "

fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh

Lines changed: 7 additions & 0 deletions
@@ -55,6 +55,9 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel(
     {%- endif %}
     const uint32_t shfl_sync_mask,
     const int32_t max_vecs_per_thread,
+    {%- if ssd %}
+    const bool enable_optimizer_offloading,
+    {%- endif %}
     {{ args.split_ref_kernel_args | replace_pta_namespace() | join(",\n ") }}
 ) {
     constexpr auto kIsInt8 = std::is_same_v<emb_t, uint8_t>;

@@ -113,6 +116,10 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel(
         }
     }

+    {%- if not ssd %}
+    constexpr auto enable_optimizer_offloading = false;
+    {%- endif %}
+
     {{ split_precomputation }}

     {# /* Note: technically, global weight decay (gwd) compensation should be done before
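The second hunk above is what keeps the generated `{{ split_precomputation }}` body uniform across builds: non-SSD instantiations define `enable_optimizer_offloading` as a `constexpr` false local instead of taking it as a kernel argument, so the shared body always compiles and the offloading branch becomes dead code the compiler can eliminate. A small host-side C++ sketch of that pattern, with illustrative placeholder update bodies:

#include <cstdio>

// Stand-in for the shared, generated update body: it always references the
// flag, regardless of which variant supplies it
float shared_update_body(float g_avg_square, bool enable_optimizer_offloading) {
  if (enable_optimizer_offloading) {
    return g_avg_square + 1.0f;  // placeholder for the offloaded-state path
  }
  return g_avg_square;  // placeholder for the momentum1 path
}

// SSD variant: the flag is a real runtime parameter
float ssd_update(float g, bool enable_optimizer_offloading) {
  return shared_update_body(g, enable_optimizer_offloading);
}

// Non-SSD variant: the flag is pinned to false at compile time, mirroring the
// `{%- if not ssd %}` branch above
float dense_update(float g) {
  constexpr auto enable_optimizer_offloading = false;
  return shared_update_body(g, enable_optimizer_offloading);
}

int main() {
  std::printf("%f %f\n", ssd_update(1.0f, true), dense_update(1.0f));
  return 0;
}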

fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp

Lines changed: 23 additions & 2 deletions
@@ -251,6 +251,9 @@ enum SSDTensor {
     {%- endif %}
     const bool /*use_uniq_cache_locations_bwd*/,
     const bool /*use_homogeneous_placements*/,
+    {%- if ssd %}
+    const bool /*enable_optimizer_offloading*/,
+    {%- endif %}
     {%- if is_gwd %}
     {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %}
     const Tensor& /*prev_iter_dev*/,

@@ -319,6 +322,9 @@ enum SSDTensor {
     {%- if not dense %}
     use_uniq_cache_locations_bwd,
     use_homogeneous_placements,
+    {%- if ssd %}
+    enable_optimizer_offloading,
+    {%- endif %}
     {%- endif %}
     {%- if is_gwd %}
     {%- if "prev_iter_dev" not in args_pt2.split_function_arg_names %}

@@ -399,6 +405,7 @@ enum SSDTensor {
     {%- for tensor in ssd_tensors %}
     ret.push_back(Variable()); // {{ tensor }}
     {%- endfor %}
+    ret.push_back(Variable()); // enable_optimizer_offloading
     {%- endif %}
     {{ args_pt2.unified_pt2.split_variables | join("\n") }}
     return ret;

@@ -468,6 +475,7 @@ enum SSDTensor {
     aux_bool,
     {%- if ssd %}
     ssd_tensors.value(),
+    enable_optimizer_offloading,
     {%- endif %}
     {{ args_pt2.unified_pt2.split_function_arg_names | join(", ") }}
     {%- endif %}

@@ -628,6 +636,7 @@ class {{ autograd_func }} :
     {%- endif %}
     {%- if ssd %}
     const at::TensorList& ssd_tensors,
+    const bool enable_optimizer_offloading,
     {%- endif %}
     {{ args_pt2.unified_pt2.split_function_args | join(", ") }}) {

@@ -817,6 +826,11 @@ class {{ autograd_func }} :
     ctx->saved_data["use_uniq_cache_locations_bwd"] = static_cast<bool>(aux_bool[IDX_USE_UNIQ_CACHE_LOCATIONS_BWD]);
     ctx->saved_data["use_homogeneous_placements"] = static_cast<bool>(aux_bool[IDX_USE_HOMOGENEOUS_PLACEMENTS]);
     {%- endif %}
+
+    {%- if ssd %}
+    ctx->saved_data["enable_optimizer_offloading"] = enable_optimizer_offloading;
+    {%- endif %}
+
     const auto iter = aux_int[IDX_ITER];
     ctx->saved_data["iter"] = iter;
     {%- if is_gwd %}

@@ -950,6 +964,11 @@ static torch::autograd::variable_list backward(
     const auto use_uniq_cache_locations_bwd = ctx->saved_data["use_uniq_cache_locations_bwd"].toBool();
     const auto use_homogeneous_placements = ctx->saved_data["use_homogeneous_placements"].toBool();
     {%- endif %}
+
+    {%- if ssd %}
+    const auto enable_optimizer_offloading = ctx->saved_data["enable_optimizer_offloading"].toBool();
+    {%- endif %}
+
     {%- if is_gwd or "iter" in args_pt2.unified_pt2.split_unpacked_arg_names %}
     const auto iter = ctx->saved_data["iter"].toInt();
     {%- endif %}

@@ -1148,7 +1167,8 @@ Tensor {{ bwd_mdesc }}_embedding_codegen_lookup_{{ optimizer }}_function_pt2(
     const c10::SymInt max_B_feature_rank = -1,
     {%- if ssd %}
     const c10::SymInt vbe_output_size = -1,
-    const std::optional<at::TensorList>& ssd_tensors = std::nullopt
+    const std::optional<at::TensorList>& ssd_tensors = std::nullopt,
+    bool enable_optimizer_offloading = false
     {%- else %}
     const c10::SymInt vbe_output_size = -1
     {%- endif %}

@@ -1242,7 +1262,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
     " SymInt max_B_feature_rank=-1, "
     {%- if ssd %}
     " SymInt vbe_output_size=-1, "
-    " Tensor[]? ssd_tensors=None"
+    " Tensor[]? ssd_tensors=None, "
+    " bool enable_optimizer_offloading=False "
     {%- else %}
     " SymInt vbe_output_size=-1 "
     {%- endif %}