CMakeLists.txt (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
cmake_minimum_required(VERSION 3.23)

-project(cudnn_frontend VERSION 1.14.0)
+project(cudnn_frontend VERSION 1.14.1)

option(CUDNN_FRONTEND_SKIP_JSON_LIB "Defines whether FE should not include nlohmann/json.hpp." OFF)
option(CUDNN_FRONTEND_BUILD_SAMPLES "Defines if samples are built or not." ON)
include/cudnn_frontend/graph_properties.h (57 changes: 54 additions & 3 deletions)
@@ -847,6 +847,10 @@ class Pointwise_attributes : public Attributes<Pointwise_attributes> {
std::optional<float> relu_upper_clip;
std::optional<float> relu_lower_clip_slope;

std::optional<float> swish_beta;
std::optional<float> elu_alpha;
std::optional<float> softplus_beta;

public:
enum class input_names { IN_0, IN_1, IN_2 };
std::unordered_map<input_names, std::shared_ptr<Tensor_attributes>> inputs;
@@ -861,7 +865,10 @@ class Pointwise_attributes : public Attributes<Pointwise_attributes> {
axis,
relu_lower_clip,
relu_upper_clip,
-relu_lower_clip_slope)
+relu_lower_clip_slope,
+swish_beta,
+elu_alpha,
+softplus_beta)

Pointwise_attributes&
set_mode(PointwiseMode_t const value) {
@@ -897,6 +904,24 @@ class Pointwise_attributes : public Attributes<Pointwise_attributes> {
this->relu_upper_clip = value;
return *this;
}

Pointwise_attributes&
set_swish_beta(float const value) {
this->swish_beta = value;
return *this;
}

Pointwise_attributes&
set_elu_alpha(float const value) {
this->elu_alpha = value;
return *this;
}

Pointwise_attributes&
set_softplus_beta(float const value) {
this->softplus_beta = value;
return *this;
}
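// These three setters pair with the SWISH, ELU and SOFTPLUS pointwise modes;
// a hedged usage sketch follows the pointwise.h diff below.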
};

class Instancenorm_backward_attributes : public Attributes<Instancenorm_backward_attributes> {
@@ -1617,6 +1642,7 @@ class SDPA_attributes : public Attributes<SDPA_attributes> {
Descale_S,
Scale_S,
Scale_O,
SINK_TOKEN,
};
std::unordered_map<input_names, std::shared_ptr<Tensor_attributes>> inputs;
enum class output_names { O, Stats, RNG_DUMP, Amax_S, Amax_O };
@@ -1813,6 +1839,12 @@ class SDPA_attributes : public Attributes<SDPA_attributes> {
return *this;
}

SDPA_attributes&
set_sink_token(std::shared_ptr<Tensor_attributes> value) {
inputs[SDPA_attributes::input_names::SINK_TOKEN] = std::move(value);
return *this;
}
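// Reading of the feature (an assumption, inferred from the softmax/backward
// wiring in this PR): the sink token is an extra logit that joins the softmax
// normalization without contributing a value row. See the usage sketch after
// this file's diff.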

SDPA_attributes&
set_implementation(AttentionImplementation_t value) {
implementation = value;
@@ -1904,9 +1936,10 @@ class SDPA_backward_attributes : public Attributes<SDPA_backward_attributes> {
Dropout_mask,
Dropout_scale,
Dropout_scale_inv,
SINK_TOKEN,
};
std::unordered_map<input_names, std::shared_ptr<Tensor_attributes>> inputs;
enum class output_names { dQ, dK, dV, dBias, RNG_DUMP };
enum class output_names { dQ, dK, dV, dBias, RNG_DUMP, DSINK_TOKEN };
std::unordered_map<output_names, std::shared_ptr<Tensor_attributes>> outputs;
NLOHMANN_DEFINE_TYPE_INTRUSIVE(SDPA_backward_attributes,
name,
@@ -2080,6 +2113,18 @@ class SDPA_backward_attributes : public Attributes<SDPA_backward_attributes> {
is_deterministic_algorithm = value;
return *this;
}

SDPA_backward_attributes&
set_sink_token(std::shared_ptr<Tensor_attributes> value) {
inputs[SDPA_backward_attributes::input_names::SINK_TOKEN] = value;
return *this;
}

SDPA_backward_attributes&
set_dsink_token(std::shared_ptr<Tensor_attributes> value) {
outputs[SDPA_backward_attributes::output_names::DSINK_TOKEN] = value;
return *this;
}
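// DSINK_TOKEN carries the gradient for the SINK_TOKEN input; validation in
// the SDPA backward node below requires that a graph setting dsink also sets
// sink, e.g.
//     sdpa_backward_attributes.set_sink_token(sink).set_dsink_token(dsink);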
};

class SDPA_fp8_backward_attributes : public Attributes<SDPA_fp8_backward_attributes> {
@@ -2222,7 +2267,7 @@ class Softmax_attributes : public Attributes<Softmax_attributes> {
std::optional<bool> use_M_Zinv;

public:
-enum class input_names { P };
+enum class input_names { P, SINK };
std::unordered_map<input_names, std::shared_ptr<Tensor_attributes>> inputs;
enum class output_names { S, Stats, M, Zinv };
std::unordered_map<output_names, std::shared_ptr<Tensor_attributes>> outputs;
@@ -2239,6 +2284,12 @@ class Softmax_attributes : public Attributes<Softmax_attributes> {
use_M_Zinv = value;
return *this;
}

Softmax_attributes&
set_sink(std::shared_ptr<Tensor_attributes> value) {
inputs[Softmax_attributes::input_names::SINK] = value;
return *this;
}
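// Intended semantics (an assumption, inferred from the backward pass below):
// with sink logit s, row i is normalized as S[i][j] = exp(P[i][j] - stats[i])
// where stats[i] = logsumexp(s, P[i][0], ..., P[i][n-1]); exp(s) joins the
// denominator but produces no output column, and exp(s - stats[i]) is the
// sink's probability mass in row i.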
};

class Conv_wgrad_attributes : public Attributes<Conv_wgrad_attributes> {
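Taken together, the new attributes suggest forward usage along the following lines. This is a minimal sketch, assuming the existing FE 1.x graph API (graph.tensor, graph.sdpa, set_attn_scale) and guessing a one-logit-per-head shape for the sink tensor; the names, dims and strides are illustrative, not taken from this PR.

#include <cmath>
#include <string>
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

void build_sdpa_with_sink() {
    int64_t const b = 4, h = 8, s = 1024, d = 128;  // illustrative shapes
    fe::graph::Graph graph;

    auto make_qkv = [&](std::string const& name) {
        return graph.tensor(fe::graph::Tensor_attributes()
                                .set_name(name)
                                .set_dim({b, h, s, d})
                                .set_stride({h * s * d, s * d, d, 1})
                                .set_data_type(fe::DataType_t::HALF));
    };
    auto Q = make_qkv("Q"), K = make_qkv("K"), V = make_qkv("V");

    // Guessed sink shape: one learnable logit per attention head.
    auto sink = graph.tensor(fe::graph::Tensor_attributes()
                                 .set_name("sink")
                                 .set_dim({1, h, 1, 1})
                                 .set_stride({h, 1, 1, 1})
                                 .set_data_type(fe::DataType_t::FLOAT));

    auto [O, stats] = graph.sdpa(Q, K, V,
                                 fe::graph::SDPA_attributes()
                                     .set_name("sdpa_with_sink")
                                     .set_is_inference(false)
                                     .set_attn_scale(1.0f / std::sqrt(static_cast<float>(d)))
                                     .set_sink_token(sink));
    O->set_output(true).set_data_type(fe::DataType_t::HALF);
    stats->set_output(true).set_data_type(fe::DataType_t::FLOAT);
}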
include/cudnn_frontend/node/pointwise.h (12 changes: 12 additions & 0 deletions)
@@ -99,6 +99,18 @@ class PointwiseNode : public NodeCRTP<PointwiseNode> {
pointwise_descriptor_builder.setReluUpperClip(attributes.relu_upper_clip.value());
}

if (attributes.swish_beta.has_value()) {
pointwise_descriptor_builder.setSwishBeta(attributes.swish_beta.value());
}

if (attributes.elu_alpha.has_value()) {
pointwise_descriptor_builder.setEluAlpha(attributes.elu_alpha.value());
}

if (attributes.softplus_beta.has_value()) {
pointwise_descriptor_builder.setSoftplusBeta(attributes.softplus_beta.value());
}

pointwise_descriptor_builder.setComputeType(attributes.compute_data_type);
pointwise_descriptor_builder.setMode(attributes.mode);
auto pointwise_descriptor = pointwise_descriptor_builder.build();
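The plumbing above simply forwards the optional scalars into the backend pointwise descriptor builder. A hedged sketch of exercising one of them through the graph API, assuming the existing SWISH_FWD mode (cuDNN's swish is y = x * sigmoid(beta * x)); tensor names and shapes are illustrative:

#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

void build_swish_with_beta() {
    fe::graph::Graph graph;
    auto X = graph.tensor(fe::graph::Tensor_attributes()
                              .set_name("X")
                              .set_dim({8, 64, 1, 1})
                              .set_stride({64, 1, 1, 1})
                              .set_data_type(fe::DataType_t::HALF));

    // Non-default beta, i.e. y = x * sigmoid(0.7 * x).
    auto Y = graph.pointwise(X,
                             fe::graph::Pointwise_attributes()
                                 .set_name("swish")
                                 .set_mode(fe::PointwiseMode_t::SWISH_FWD)
                                 .set_swish_beta(0.7f));
    Y->set_output(true);
}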
include/cudnn_frontend/node/scaled_dot_product_flash_attention.h (49 changes: 43 additions & 6 deletions)
@@ -208,6 +208,12 @@ class SDPANodeBase : public NodeCRTP<DerivedT> {
return validation_result;
}

// Return ATTRIBUTE_NOT_SET if a sink_token input is present on cuDNN 9.12 and below (backend version < 91300)
RETURN_CUDNN_FRONTEND_ERROR_IF(detail::get_backend_version() < 91300 &&
attributes.inputs.find(input_names::SINK_TOKEN) != attributes.inputs.end(),
error_code_t::ATTRIBUTE_NOT_SET,
"SDPA with sink_token is not supported before 9.13.");

return {error_code_t::OK, ""};
}

@@ -508,6 +514,10 @@ class CompositeSDPANode : public SDPANodeBase<CompositeSDPANode> {

auto softmax_attributes =
Softmax_attributes().set_name("softmax").has_stats(true).has_M_Zinv(false); // As this is flash attention
// Set the sink on the softmax if the user has provided a sink tensor
if (attributes.inputs.find(input_names::SINK_TOKEN) != attributes.inputs.end()) {
softmax_attributes.set_sink(attributes.inputs[input_names::SINK_TOKEN]);
}
// Special non-functional-style call. Needed because output already created and provided to user.
softmax(last_output, softmax_attributes, softmax_output, softmax_stats);
last_output = softmax_output;
@@ -751,13 +761,14 @@ class CompositeSDPABackwardNode : public NodeCRTP<CompositeSDPABackwardNode> {
if (prop.major == 9) {
// validate basic dimension requirements

-if ((128 < d_qk) && (d_qk <= 192) && (64 < d_v) && (d_v <= 128)) {
-// DeepSeek case, 9.11 only supports 192 hidden dim
-if (detail::get_backend_version() >= 91100) {
-RETURN_CUDNN_FRONTEND_ERROR_IF(((d_v == 128) && (d_qk == 192)) == false,
-error_code_t::GRAPH_NOT_SUPPORTED,
-"Num hidden_dim d_v should be equal to 128 if d_qk is 192");
+if ((detail::get_backend_version() >= 91100) && (detail::get_backend_version() < 91300)) {
+
+if ((128 < d_qk) && (d_qk <= 192) && (64 < d_v) && (d_v <= 128)) {
+
+// DeepSeek case, 9.11 only supports 192 hidden dim
+RETURN_CUDNN_FRONTEND_ERROR_IF((d_v != 128) && (d_qk != 192),
+error_code_t::GRAPH_NOT_SUPPORTED,
+"Num hidden_dim d_v should be equal to 128 if d_qk is 192");
}
}

@@ -926,6 +937,10 @@ class CompositeSDPABackwardNode : public NodeCRTP<CompositeSDPABackwardNode> {
RETURN_CUDNN_FRONTEND_ERROR_IF(this->context.get_intermediate_data_type() == DataType_t::NOT_SET,
error_code_t::ATTRIBUTE_NOT_SET,
"Intermediate tensor data type needs to be set as internal tensors require it.");
// If dsink is set, sink also needs to be set
RETURN_CUDNN_FRONTEND_ERROR_IF(attributes.outputs.find(output_names::DSINK_TOKEN) != attributes.outputs.end() && attributes.inputs.find(input_names::SINK_TOKEN) == attributes.inputs.end(),
error_code_t::ATTRIBUTE_NOT_SET,
"If dsink is set, sink also needs to be set.");
// clang-format on

return {error_code_t::OK, ""};
@@ -1120,6 +1135,28 @@ class CompositeSDPABackwardNode : public NodeCRTP<CompositeSDPABackwardNode> {
reduction(last_output, Reduction_attributes().set_name("reduce_dO_o").set_mode(ReductionMode_t::ADD));
last_output->set_dim({b, h_q, s_q, 1}).set_stride({h_q * s_q, s_q, 1, 1});

if (attributes.outputs.find(output_names::DSINK_TOKEN) != attributes.outputs.end()) {
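// Derivation sketch (inferred, not stated in this PR): last_output currently
// holds the row-wise reduction rowsum(dO * O), so the chain below assembles
//     dSink = sum_over_rows(exp(sink - stats) * rowsum(dO * O)),
// where exp(sink - stats) can be read as the sink's per-row softmax probability.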
// sub_sink = sink - stats
auto sub_sink = pointwise(attributes.inputs[input_names::SINK_TOKEN],
attributes.inputs[input_names::Stats],
Pointwise_attributes().set_name("sub_sink").set_mode(PointwiseMode_t::SUB));

// exp_sink = exp(sub_sink)
auto exp_sink =
pointwise(sub_sink, Pointwise_attributes().set_name("exp_sink").set_mode(PointwiseMode_t::EXP));

// per_token_grad = exp_sink * last_output
auto per_token_grad =
pointwise(exp_sink,
last_output,
Pointwise_attributes().set_name("mul_exp_sink_last_output").set_mode(PointwiseMode_t::MUL));

// dSink = reduce(per_token_grad)
reduction(per_token_grad,
Reduction_attributes().set_name("reduce_per_token_grad").set_mode(ReductionMode_t::ADD),
attributes.outputs[output_names::DSINK_TOKEN]);
}

// softmax_sum = last_output * dropout_scale
last_output = pointwise(last_output,
attributes.inputs[input_names::Dropout_scale_inv]