
Commit f90a2d4

Authored by Tixxx
Changes to support TNLRV3 fine-tuning (#4639)
* Added a ReduceLogSumExp gradient and a corresponding gradient-checker test.
* Fixed a type mismatch when calling the cuDNN reduce kernel.
* Fixed the Python frontend to remove redundant states so the returned state dict matches the PyTorch state dict.
1 parent d8f3e46 commit f90a2d4

File tree

7 files changed: +128 −136 lines

onnxruntime/core/providers/cuda/reduction/reduction_ops.cc

+9-3
@@ -395,10 +395,12 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
   }
 
   CudnnReduceDescriptor reduce_desc;
-  if (std::is_same<T, MLFloat16>::value)
+  if (std::is_same<T, MLFloat16>::value) {
     ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, CudnnTensor::GetDataType<float>(), ReduceTensorIndices));
-  else
+  } else {
     ORT_RETURN_IF_ERROR(reduce_desc.Set(cudnn_reduce_op, cudnn_type_X, ReduceTensorIndices));
+  }
+
   const auto one = Consts<CudaT>::One;
   const auto zero = Consts<CudaT>::Zero;
   CudnnTensor input_tensor;
@@ -437,7 +439,11 @@ Status ReduceComputeCore(CUDAExecutionProvider& cuda_ep, const Tensor& input, Pr
   } else {
     // Reduce max -- Max/Min will output indices data
     CudnnReduceDescriptor reduce_max_desc;
-    ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_type_X, CUDNN_REDUCE_TENSOR_NO_INDICES));
+    cudnnDataType_t cudnn_reduce_max_type = cudnn_type_X;
+    if ((std::is_same<T, MLFloat16>::value)) {
+      cudnn_reduce_max_type = CUDNN_DATA_FLOAT;
+    }
+    ORT_RETURN_IF_ERROR(reduce_max_desc.Set(CUDNN_REDUCE_TENSOR_MAX, cudnn_reduce_max_type, CUDNN_REDUCE_TENSOR_NO_INDICES));
     size_t indices_bytes_max = 0;
     CUDNN_RETURN_IF_ERROR(cudnnGetReductionIndicesSize(cuda_ep.PerThreadCudnnHandle(), reduce_max_desc,
                                                        input_tensor, output_tensor, &indices_bytes_max));
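Context for the hunk above: when T is MLFloat16, the main reduce descriptor is already configured with a float data type (first hunk); this change makes the ReduceMax descriptor follow the same rule, which appears to be the type mismatch mentioned in the commit message. A rough NumPy analogue of "do the reduction arithmetic in float for half-precision inputs" (illustration only, assuming NumPy; not ORT code):

import numpy as np

x = np.random.rand(4, 3, 2).astype(np.float16)

# Accumulate in float32 and cast the result back to float16,
# mirroring the float compute type used for MLFloat16 reductions.
reduced = x.sum(axis=(0, 2), keepdims=True, dtype=np.float32).astype(np.float16)
print(reduced.shape)  # (1, 3, 1)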

orttraining/orttraining/core/graph/gradient_builder.cc

+34
@@ -905,6 +905,40 @@ IMPLEMENT_GRADIENT_BUILDER(GetReduceMeanGradient) {
   return result;
 }
 
+// Reference computation is pytorch's logsumexp_backward
+// dx_i = exp(xi) / reduceSum(exp(xi))
+// O(0) = log(reduceSum(exp(xi)))
+// Self_Sub_Result = I(0) - O(0) = xi - log(reduceSum(exp(xi))) = log(exp(xi) / reduceSum(exp(xi)))
+// The gradient computation re-uses the forward op's input and output, so it is a recomputation candidate.
+IMPLEMENT_GRADIENT_BUILDER(GetReduceLogSumExpGradient) {
+  std::vector<NodeDef> result;
+  auto attributes = SrcNodeAttributes();
+  bool keepdims = true;
+  if (attributes.find("keepdims") != attributes.end() &&
+      attributes.at("keepdims").has_i()) {
+    keepdims = static_cast<bool>(attributes.at("keepdims").i());
+  }
+
+  ArgDef grad = GO(0);
+  if (!keepdims && attributes.find("axes") != attributes.end()) {
+    std::vector<int64_t> axes_values = RetrieveValues<int64_t>(attributes.at("axes"));
+    grad = IA("Unsqueezed_Grad");
+    result.push_back(NodeDef("Unsqueeze", {GO(0)}, {grad}, {MakeAttribute("axes", axes_values)}));
+
+    result.push_back(NodeDef("Unsqueeze", {O(0)}, {IA("Unsqueezed_Output")}, {MakeAttribute("axes", axes_values)}));
+    result.push_back(NodeDef("Sub", {I(0), IA("Unsqueezed_Output")}, {IA("Self_Sub_Result")}));
+  }
+  else {
+    result.push_back(NodeDef("Sub", {I(0), O(0)}, {IA("Self_Sub_Result")}));
+  }
+
+  result.push_back(NodeDef("Exp", {IA("Self_Sub_Result")}, {IA("Self_Sub_Result_Exp")}));
+
+  result.push_back(NodeDef("Mul", {IA("Self_Sub_Result_Exp"), grad}, {GI(0)}));
+
+  return result;
+}
+
 IMPLEMENT_GRADIENT_BUILDER(GetReduceSumGradient) {
   std::vector<NodeDef> result;
   auto attributes = SrcNodeAttributes();
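For reference, the algebra the new builder encodes, restated from the comment above (I(0), O(0), GO(0), GI(0) are the forward input, forward output, incoming gradient, and produced gradient):

\[
y = \log \sum_j e^{x_j}, \qquad
\frac{\partial y}{\partial x_i} = \frac{e^{x_i}}{\sum_j e^{x_j}} = e^{x_i - y}, \qquad
\frac{\partial L}{\partial x_i} = e^{x_i - y} \cdot \frac{\partial L}{\partial y},
\]

which is exactly the Sub (I(0) − O(0)), Exp, Mul-by-GO(0) chain; the Unsqueeze branch only restores the reduced axes when keepdims = 0 so the subtraction and multiplication broadcast correctly.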

orttraining/orttraining/core/graph/gradient_builder.h

+1
@@ -25,6 +25,7 @@ DECLARE_GRADIENT_BUILDER(GetMulGradient)
 DECLARE_GRADIENT_BUILDER(GetDivGradient)
 DECLARE_GRADIENT_BUILDER(GetReduceMeanGradient)
 DECLARE_GRADIENT_BUILDER(GetReduceSumGradient)
+DECLARE_GRADIENT_BUILDER(GetReduceLogSumExpGradient)
 DECLARE_GRADIENT_BUILDER(GetPowGradient)
 DECLARE_GRADIENT_BUILDER(GetConcatGradient)
 DECLARE_GRADIENT_BUILDER(GetReshapeGradient)

orttraining/orttraining/core/graph/gradient_builder_registry.cc

+1
@@ -51,6 +51,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() {
   REGISTER_GRADIENT_BUILDER("Pow", GetPowGradient);
   REGISTER_GRADIENT_BUILDER("ReduceMean", GetReduceMeanGradient);
   REGISTER_GRADIENT_BUILDER("ReduceSum", GetReduceSumGradient);
+  REGISTER_GRADIENT_BUILDER("ReduceLogSumExp", GetReduceLogSumExpGradient);
   REGISTER_GRADIENT_BUILDER("Add", GetAddSubGradient);
   REGISTER_GRADIENT_BUILDER("Sub", GetAddSubGradient);
   REGISTER_GRADIENT_BUILDER("Mul", GetMulGradient);

orttraining/orttraining/python/ort_trainer.py

+7-1
@@ -629,6 +629,8 @@ def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_opti
         self.world_size = world_size
         self.use_mixed_precision = use_mixed_precision
 
+        self.original_model_state_keys = list(model.state_dict().keys()) if hasattr(model, 'state_dict') else []
+
         self.session = None
         self.device_ = device
         self.gradient_accumulation_steps = gradient_accumulation_steps
@@ -773,7 +775,11 @@ def state_dict(self):
             if n.name not in torch_state:
                 torch_state[n.name] = torch.from_numpy(numpy_helper.to_array(n))
 
-        return torch_state
+        # Remove redundant initializers and name suffixes to map back to the original torch state names.
+        torch_state_to_return = {key: torch_state[key] for key in self.original_model_state_keys if key in torch_state} \
+            if self.original_model_state_keys \
+            else torch_state
+        return torch_state_to_return
 
     def load_state_dict(self, state_dict, strict=False):
         # Note: It may happen ONNX model has not yet been initialized
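The returned dictionary therefore contains exactly the keys the wrapped torch module originally reported. A standalone sketch of the comprehension's behavior (the key names below are made up for illustration, not taken from ORT or the PR):

original_model_state_keys = ["fc.weight", "fc.bias"]

torch_state = {
    "fc.weight": "w",
    "fc.bias": "b",
    "fc.weight_fp16": "w_half",       # hypothetical redundant mixed-precision copy
    "some_internal_buffer": "extra",  # hypothetical initializer with no torch counterpart
}

# Keep only the keys the original torch state dict had, in their original order.
filtered = {key: torch_state[key]
            for key in original_model_state_keys
            if key in torch_state}

assert list(filtered.keys()) == original_model_state_keys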

orttraining/orttraining/test/gradient/gradient_op_test_utils.h

+4
@@ -7,6 +7,9 @@
 
 namespace onnxruntime {
 namespace test {
+using TestDataVector = std::tuple<std::vector<std::vector<TensorInfo>>,              // Input data
+                                  std::vector<std::vector<TensorInfo>>,              // output data
+                                  std::vector<std::vector<onnx::AttributeProto>>>;   // attribute
 
 class GradientOpTester : public OpTester {
  public:
@@ -39,3 +42,4 @@ class GradientOpTester : public OpTester {
 };
 }  // namespace test
 }  // namespace onnxruntime
+

orttraining/orttraining/test/gradient/gradient_ops_test.cc

+72-132
@@ -38,6 +38,70 @@ static bool IsErrorWithinTolerance(float error, float tolerance) {
 #define EXPECT_IS_TINY(max_error) \
   EXPECT_IS_TINIER_THAN(max_error, 1.5e-2f)
 
+static void RunReductionTests(const OpDef& op_def) {
+
+  TestDataVector test_data(
+      // Input X
+      {
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+          {{4, 3, 2}},
+      },
+      // Output Y
+      {
+          {{1, 1, 1}},
+          {{}},
+          {{1, 3, 1}},
+          {{2}},
+          {{4, 1, 2}},
+          {{4, 3}},
+          {{4, 1, 2}},
+          {{4}}
+      },
+      // Attributes
+      {
+          // default
+          {},
+          // axes = [0, 1, 2], keepdims = 0
+          {MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
+           MakeAttribute("keepdims", int64_t(0))},
+          // axes = [0, 2], keepdims = 1
+          {MakeAttribute("axes", std::vector<int64_t>{0, 2})},
+          // axes = [0, 1], keepdims = 0
+          {MakeAttribute("axes", std::vector<int64_t>{0, 1}),
+           MakeAttribute("keepdims", int64_t(0))},
+          // axes = [1], keepdims = 1
+          {MakeAttribute("axes", std::vector<int64_t>{1}),
+           MakeAttribute("keepdims", int64_t(1))},
+          // axes = [2], keepdims = 0
+          {MakeAttribute("axes", std::vector<int64_t>{2}),
+           MakeAttribute("keepdims", int64_t(0))},
+          // axes = [-2], keepdims = 1
+          {MakeAttribute("axes", std::vector<int64_t>{-2}),
+           MakeAttribute("keepdims", int64_t(1))},
+          // axes = [-2, -1], keepdims = 0
+          {MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
+           MakeAttribute("keepdims", int64_t(0))}
+      });
+
+  GradientChecker<float, float, float> gradient_checker;
+
+  float max_error;
+
+  for (size_t i = 0; i < std::get<0>(test_data).size(); i++) {
+    max_error = 0;
+    gradient_checker.ComputeGradientError(op_def, std::get<0>(test_data)[i],
+                                          std::get<1>(test_data)[i], &max_error,
+                                          std::get<2>(test_data)[i]);
+    EXPECT_IS_TINY(max_error);
+  }
+}
+
 template <typename T>
 void GenerateRandomDataWithOneHot(
     std::vector<std::vector<float>>& x_datas,
@@ -426,149 +490,24 @@ TEST(GradientCheckerTest, GemmGrad) {
 }
 
 TEST(GradientCheckerTest, ReduceMeanGrad) {
-  float max_error;
-  GradientChecker<float, float, float> gradient_checker;
   // Attribute axes supports negative values from opset 11.
   OpDef op_def{"ReduceMean", kOnnxDomain, 11};
 
-  // default
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // TODO: Fix forward kernel behavior for default axes
-  // default axes, keepdims = 0
-  /*
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
-                                          {MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-  */
-
-  // axes = [0, 1, 2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 2})});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 1], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [1], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{1}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [-2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-2}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [-2, -1], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-2, -1}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
+  RunReductionTests(op_def);
 }
 
 TEST(GradientCheckerTest, ReduceSumGrad) {
-  float max_error;
-  GradientChecker<float, float, float> gradient_checker;
   // Attribute axes supports negative values from opset 11.
   OpDef op_def{"ReduceSum", kOnnxDomain, 11};
 
-  // default
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 1, 1}}, &max_error);
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 1, 2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1, 2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{1, 3, 1}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 2})});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [0, 1], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{0, 1}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [1], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{1}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
-
-  // axes = [2], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 3}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{2}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
+  RunReductionTests(op_def);
+}
 
-  // axes = [-2], keepdims = 1
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{4, 1, 2}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-2}),
-                                           MakeAttribute("keepdims", int64_t(1))});
-    EXPECT_IS_TINY(max_error);
-  }
+TEST(GradientCheckerTest, ReduceLogSumExpGrad) {
+  // Attribute axes supports negative values from opset 11.
+  OpDef op_def{"ReduceLogSumExp", kOnnxDomain, 11};
 
-  // axes = [-1, -3], keepdims = 0
-  {
-    gradient_checker.ComputeGradientError(op_def, {{4, 3, 2}}, {{3}}, &max_error,
-                                          {MakeAttribute("axes", std::vector<int64_t>{-1, -3}),
-                                           MakeAttribute("keepdims", int64_t(0))});
-    EXPECT_IS_TINY(max_error);
-  }
+  RunReductionTests(op_def);
 }
 
 #ifndef USE_CUDA
@@ -1960,3 +1899,4 @@ TEST(GradientCheckerTest, ExpandGrad) {
 }  // namespace onnxruntime
 
 #endif  // NDEBUG
+
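To make the consolidated check concrete, here is a small NumPy sketch (illustration only, not code from the PR) of the kind of comparison GradientChecker performs for the new ReduceLogSumExp gradient: the analytic gradient from the builder's Sub → Exp → Mul chain against a central-difference estimate.

import numpy as np

def logsumexp(x, axes):
    # keepdims=True ReduceLogSumExp over the given axes
    m = x.max(axis=axes, keepdims=True)
    return np.log(np.exp(x - m).sum(axis=axes, keepdims=True)) + m

rng = np.random.default_rng(0)
x = rng.random((4, 3, 2))
axes = (0, 2)

y = logsumexp(x, axes)            # forward output, O(0)
dy = rng.random(y.shape)          # incoming gradient, GO(0)
dx_analytic = np.exp(x - y) * dy  # the builder's Sub -> Exp -> Mul chain

# Central-difference estimate of the same gradient
eps = 1e-6
dx_numeric = np.zeros_like(x)
it = np.nditer(x, flags=["multi_index"])
for _ in it:
    i = it.multi_index
    xp, xm = x.copy(), x.copy()
    xp[i] += eps
    xm[i] -= eps
    dx_numeric[i] = ((logsumexp(xp, axes) - logsumexp(xm, axes)) * dy).sum() / (2 * eps)

assert np.allclose(dx_analytic, dx_numeric, atol=1e-4)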
