Skip to content

Commit e21b948

Browse files
authored
[webgpu] Optimize string stream used in WebGPU EP (#27223)
### Description Optimize the string stream used in WebGPU EP. ### Motivation and Context The current implementation uses an `absl::OStringStream`, which is faster than `std::ostringstream`. However, it is still slow when used to generate the program cache key. Profiling data shows that `CalculateProgramCacheKey()` is extremely time-consuming: it can consume up to 1/3 of all CPU time inside `WebGpuContext::Run()`: <img width="1035" height="185" alt="image" src="https://github.com/user-attachments/assets/5b9e33cc-cd0a-4ef8-9a92-2ee894b85156" /> Basic analysis shows that most of the time is spent in the `std::basic_ostream operator <<()` implementation, which is much slower than expected. To optimize this, the PR introduces a simplified implementation, `FastOStringStream`, which does not inherit from `std::basic_ostream`. Instead, the class implements only the `operator<<` overloads needed for generating cache keys and shader code, reducing unnecessary overhead as much as possible. <img width="1016" height="156" alt="image" src="https://github.com/user-attachments/assets/32e3d345-c56d-4e6d-89e1-99cc7b150d8e" /> As a result, the CPU sample count of `CalculateProgramCacheKey()` in the same test dropped from 2555 to 176. Generation TPS in the E2E model benchmark on Qwen3-0.6B increased from ~90 to ~130 on Windows11/13900k/RTX4070.
1 parent 685895c commit e21b948

File tree

14 files changed

+122
-62
lines changed

14 files changed

+122
-62
lines changed

onnxruntime/core/providers/webgpu/nn/pool.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,8 @@ Status PoolProgram::GenerateShaderCode(ShaderHelper& shader) const {
8484

8585
constexpr const size_t kStringInitialSize = 128;
8686
if (is_max_pool_) {
87-
std::string f16_min = "f16(-65504)";
88-
89-
SS(f32_min_ss, kStringInitialSize);
90-
f32_min_ss << "f32(" << std::numeric_limits<float>::lowest() << ")";
91-
std::string f32_min = SS_GET(f32_min_ss);
92-
9387
SS(var_decl_ss, kStringInitialSize);
94-
var_decl_ss << " var value = " << (is_float16_ ? f16_min : f32_min) << ";\n";
88+
var_decl_ss << " var value = " << (is_float16_ ? "-65504.0h" : "-3.4028234663852886e+38f") << ";\n";
9589
var_decl_code = SS_GET(var_decl_ss);
9690

9791
sampling_code = " value = max(value, x_val);\n";

onnxruntime/core/providers/webgpu/program.cc

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,17 +49,18 @@ ProgramUniformVariableValue::ProgramUniformVariableValue(ProgramUniformVariableD
4949
memcpy(data.data(), ptr, length * element_byte_size);
5050
}
5151

52-
std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType type) {
53-
os << ProgramUniformVariableDataTypeName[std::underlying_type<decltype(type)>::type(type)];
54-
return os;
55-
}
52+
#define DEFINE_ENUM_STREAM_OP(StreamType, EnumType, EnumNameArray) \
53+
StreamType& operator<<(StreamType& os, EnumType type) { \
54+
os << EnumNameArray[std::underlying_type<decltype(type)>::type(type)]; \
55+
return os; \
56+
}
5657

57-
std::ostream& operator<<(std::ostream& os, ProgramConstantDataType type) {
58-
os << ProgramConstantDataTypeName[std::underlying_type<decltype(type)>::type(type)];
59-
return os;
60-
}
58+
DEFINE_ENUM_STREAM_OP(std::ostream, ProgramUniformVariableDataType, ProgramUniformVariableDataTypeName)
59+
DEFINE_ENUM_STREAM_OP(OStringStream, ProgramUniformVariableDataType, ProgramUniformVariableDataTypeName)
60+
DEFINE_ENUM_STREAM_OP(std::ostream, ProgramConstantDataType, ProgramConstantDataTypeName)
61+
DEFINE_ENUM_STREAM_OP(OStringStream, ProgramConstantDataType, ProgramConstantDataTypeName)
6162

62-
std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency dep) {
63+
OStringStream& operator<<(OStringStream& os, ProgramTensorMetadataDependency dep) {
6364
bool first = true;
6465
if ((dep & ProgramTensorMetadataDependency::Type) == ProgramTensorMetadataDependency::Type) {
6566
os << "Type";
@@ -109,10 +110,7 @@ constexpr std::string_view ProgramVariableDataTypeName[] = {
109110
"i4x8", // Int4x8
110111
};
111112

112-
std::ostream& operator<<(std::ostream& os, ProgramVariableDataType type) {
113-
os << ProgramVariableDataTypeName[std::underlying_type<decltype(type)>::type(type)];
114-
return os;
115-
}
113+
DEFINE_ENUM_STREAM_OP(OStringStream, ProgramVariableDataType, ProgramVariableDataTypeName)
116114
#endif
117115

118116
int NumberOfComponents(ProgramVariableDataType type) {

onnxruntime/core/providers/webgpu/program.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
#include "core/common/safeint.h"
2424
#include "core/framework/tensor.h"
2525

26+
#include "core/providers/webgpu/string_utils.h"
27+
2628
namespace onnxruntime {
2729
namespace webgpu {
2830
class ShaderHelper;
@@ -37,6 +39,7 @@ enum class ProgramUniformVariableDataType {
3739
Int32,
3840
};
3941
std::ostream& operator<<(std::ostream& os, ProgramUniformVariableDataType);
42+
OStringStream& operator<<(OStringStream& os, ProgramUniformVariableDataType);
4043

4144
constexpr size_t ProgramUniformVariableDataTypeSize[] = {sizeof(float), sizeof(uint16_t), sizeof(uint32_t), sizeof(int32_t)};
4245

@@ -80,6 +83,7 @@ enum class ProgramConstantDataType {
8083
Bool
8184
};
8285
std::ostream& operator<<(std::ostream& os, ProgramConstantDataType);
86+
OStringStream& operator<<(OStringStream& os, ProgramConstantDataType);
8387

8488
constexpr std::string_view ProgramConstantDataTypeName[] = {"f32", "f16", "u32", "i32", "bool"};
8589

@@ -158,7 +162,7 @@ enum class ProgramTensorMetadataDependency : int {
158162
TypeAndRank = Type | Rank,
159163
TypeAndShape = Type | Shape,
160164
};
161-
std::ostream& operator<<(std::ostream& os, ProgramTensorMetadataDependency);
165+
OStringStream& operator<<(OStringStream& os, ProgramTensorMetadataDependency);
162166

163167
#if defined(__GNUC__)
164168
#pragma GCC diagnostic push
@@ -216,7 +220,7 @@ enum class ProgramVariableDataType {
216220
// if you add a new type here, you also need to update ProgramVariableDataTypeName
217221
};
218222
#ifndef NDEBUG
219-
std::ostream& operator<<(std::ostream& os, ProgramVariableDataType);
223+
OStringStream& operator<<(OStringStream& os, ProgramVariableDataType);
220224
#endif
221225

222226
int NumberOfComponents(ProgramVariableDataType type);

onnxruntime/core/providers/webgpu/program_cache_key.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace webgpu {
1717

1818
namespace {
1919
// append the info of an input or output to the cachekey
20-
void AppendTensorInfo(std::ostream& ss,
20+
void AppendTensorInfo(OStringStream& ss,
2121
const TensorShape& tensor_shape,
2222
ProgramVariableDataType var_type,
2323
ProgramTensorMetadataDependency dependency,

onnxruntime/core/providers/webgpu/shader_helper.cc

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ ShaderHelper::ShaderHelper(const ProgramBase& program,
3434
dispatch_group_size_z_{dispatch_group_size_z},
3535
program_{program},
3636
program_metadata_{program_metadata},
37-
additional_implementation_ss_{&additional_implementation_},
38-
body_ss_{&body_} {}
37+
additional_implementation_ss_{kStringInitialSizeShaderSourceCodeAdditionalImplementation},
38+
body_ss_{kStringInitialSizeShaderSourceCodeMain} {}
3939

4040
Status ShaderHelper::Init() {
4141
// dispatch group size is normalized so no need to validate it here
@@ -59,8 +59,6 @@ Status ShaderHelper::Init() {
5959
// init body string stream
6060
bool is_1d_dispatch = dispatch_group_size_y_ == 1 && dispatch_group_size_z_ == 1;
6161
bool use_indirect_dispatch = program_.IndirectDispatchTensor() != nullptr;
62-
body_.reserve(4096);
63-
additional_implementation_.reserve(1024);
6462

6563
// append header for main function so it is ready for user to append main function body
6664
body_ss_ << "@compute @workgroup_size(workgroup_size_x, workgroup_size_y, workgroup_size_z)\n"
@@ -384,7 +382,7 @@ Status ShaderHelper::ValidateIndices() const {
384382
return Status::OK();
385383
}
386384

387-
Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& shape_uniform_ranks) const {
385+
Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& shape_uniform_ranks) {
388386
SS(ss, kStringInitialSizeShaderSourceCode);
389387

390388
//
@@ -633,12 +631,12 @@ Status ShaderHelper::GenerateSourceCode(std::string& code, std::vector<int>& sha
633631
//
634632
// Additional Implementation
635633
//
636-
ss << additional_implementation_;
634+
ss << SS_GET(additional_implementation_ss_);
637635

638636
//
639637
// Main Function Body
640638
//
641-
ss << body_;
639+
ss << SS_GET(body_ss_);
642640
ss << "\n"
643641
"}\n";
644642

onnxruntime/core/providers/webgpu/shader_helper.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ class ShaderHelper final {
108108

109109
private:
110110
template <typename ConstantType> // ConstantType is one of {ProgramConstant, ProgramOverridableConstantValue, ProgramOverridableConstantDefinition}
111-
void WriteConstantValue(std::ostream& ss, const ConstantType& constant) const {
111+
void WriteConstantValue(OStringStream& ss, const ConstantType& constant) const {
112112
switch (constant.type) {
113113
case ProgramConstantDataType::Float16:
114114
ss << constant.f16.ToFloat();
@@ -156,7 +156,7 @@ class ShaderHelper final {
156156
// \param code The generated full WGSL source code.
157157
// \param shape_uniform_ranks The ranks for variables that need a uniform for the shape.
158158
//
159-
Status GenerateSourceCode(std::string& code, std::vector<int>& shape_uniform_ranks) const;
159+
Status GenerateSourceCode(std::string& code, std::vector<int>& shape_uniform_ranks);
160160
friend class ProgramManager;
161161

162162
const WebGpuContext& webgpu_context_;
@@ -175,9 +175,7 @@ class ShaderHelper final {
175175
std::vector<std::unique_ptr<ShaderVariableHelper>> input_vars_;
176176
std::vector<std::unique_ptr<ShaderVariableHelper>> output_vars_;
177177
std::vector<std::unique_ptr<ShaderIndicesHelper>> indices_vars_;
178-
std::string additional_implementation_;
179178
OStringStream additional_implementation_ss_;
180-
std::string body_;
181179
OStringStream body_ss_;
182180
};
183181

onnxruntime/core/providers/webgpu/shader_variable.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ ShaderVariableHelper::ShaderVariableHelper(std::string_view name, ProgramVariabl
150150
ORT_ENFORCE(num_components_ > 0, "Invalid number of components for variable ", name_);
151151
}
152152

153-
void ShaderIndicesHelper::Impl(std::ostream& ss) const {
153+
void ShaderIndicesHelper::Impl(OStringStream& ss) const {
154154
// Start generating code
155155

156156
const std::string shape = (usage_ & ShaderUsage::UseUniform) ? "uniforms." + name_ + "_shape" : name_ + "_shape";
@@ -249,7 +249,7 @@ void ShaderIndicesHelper::Impl(std::ostream& ss) const {
249249
}
250250
}
251251

252-
void ShaderVariableHelper::Impl(std::ostream& ss) const {
252+
void ShaderVariableHelper::Impl(OStringStream& ss) const {
253253
ShaderIndicesHelper::Impl(ss);
254254

255255
// Implementation of "fn set_{name}"

onnxruntime/core/providers/webgpu/shader_variable.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ class ShaderIndicesHelper {
130130
protected:
131131
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderIndicesHelper);
132132

133-
void Impl(std::ostream& ss) const;
133+
void Impl(OStringStream& ss) const;
134134

135135
std::string_view IndicesType() const;
136136

@@ -197,7 +197,7 @@ class ShaderVariableHelper : public ShaderIndicesHelper {
197197
private:
198198
ORT_DISALLOW_COPY_AND_ASSIGNMENT(ShaderVariableHelper);
199199

200-
void Impl(std::ostream& ss) const;
200+
void Impl(OStringStream& ss) const;
201201

202202
std::string GetByOffsetImpl(std::string_view offset) const;
203203
std::string SetByOffsetImpl(std::string_view offset, std::string_view value) const;

onnxruntime/core/providers/webgpu/string_macros.h

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66
#include "core/providers/webgpu/string_utils.h"
77

88
// macro "SS" - declare an ostream variable and its string buffer
9-
#define SS(ss, reserve_size) \
10-
std::string ss##_str; \
11-
ss##_str.reserve(reserve_size); \
12-
::onnxruntime::webgpu::OStringStream ss(&ss##_str)
9+
#define SS(ss, reserve_size) ::onnxruntime::webgpu::OStringStream ss(reserve_size)
1310

1411
// macro "SS_GET" - get the string from the ostream
15-
#define SS_GET(ss) ss##_str
12+
#define SS_GET(ss) (std::move(ss).str())
1613

1714
// macro "SS_APPEND" - use function call style to append to the ostream
1815
#define SS_APPEND(ss, ...) ::onnxruntime::webgpu::detail::OStringStreamAppend(ss, __VA_ARGS__)

onnxruntime/core/providers/webgpu/string_utils.h

Lines changed: 83 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,15 @@
55

66
#include "core/common/make_string.h"
77

8+
#include <array>
9+
#include <charconv>
10+
811
#ifdef _MSC_VER
912
#pragma warning(push)
1013
// C4702: unreachable code
1114
#pragma warning(disable : 4702)
1215
#endif // _MSC_VER
1316

14-
#include <absl/strings/internal/ostringstream.h>
15-
1617
#ifdef _MSC_VER
1718
#pragma warning(pop)
1819
#endif // _MSC_VER
@@ -22,32 +23,102 @@ namespace webgpu {
2223

2324
constexpr const size_t kStringInitialSizeSetByOffsetImpl = 128;
2425
constexpr const size_t kStringInitialSizeGetByOffsetImpl = 128;
25-
constexpr const size_t kStringInitialSizeShaderSourceCode = 2048;
26-
#ifndef NDEBUG
26+
constexpr const size_t kStringInitialSizeShaderSourceCode = 4096;
27+
constexpr const size_t kStringInitialSizeShaderSourceCodeAdditionalImplementation = 1024;
28+
constexpr const size_t kStringInitialSizeShaderSourceCodeMain = 3068;
2729
constexpr const size_t kStringInitialSizeCacheKey = 512;
28-
#else
29-
constexpr const size_t kStringInitialSizeCacheKey = 256;
30-
#endif
3130

32-
using OStringStream = absl::strings_internal::OStringStream;
31+
namespace detail {
32+
33+
// A simpler and faster ostringstream implementation than absl::strings_internal::OStringStream
34+
//
35+
// This FastOStringStream class is intended to be used in very performance critical paths. It does
36+
// not inherit from std::ostream so that it can avoid the following overheads:
37+
// - locale handling and formatting
38+
// - state management (e.g. error handling, badbit, EOF, I/O sync)
39+
// - unnecessary heap allocations
40+
// - virtual function calls
41+
//
42+
// This class is mainly used for generating shader source code and program cache keys.
43+
//
44+
class FastOStringStream {
45+
public:
46+
explicit FastOStringStream(size_t reserve_size) {
47+
str_.reserve(reserve_size);
48+
}
49+
50+
std::string str() && {
51+
return std::move(str_);
52+
}
53+
54+
// String types
55+
FastOStringStream& operator<<(const char* s) {
56+
str_.append(s);
57+
return *this;
58+
}
59+
60+
FastOStringStream& operator<<(const std::string& s) {
61+
str_.append(s);
62+
return *this;
63+
}
64+
65+
FastOStringStream& operator<<(std::string_view s) {
66+
str_.append(s);
67+
return *this;
68+
}
69+
70+
// Character
71+
FastOStringStream& operator<<(char c) {
72+
str_.push_back(c);
73+
return *this;
74+
}
75+
76+
// Integer types
77+
template <typename T>
78+
std::enable_if_t<std::is_integral_v<T> && !std::is_same_v<T, char>, FastOStringStream&>
79+
operator<<(T value) {
80+
std::array<char, 32> buffer;
81+
auto [ptr, ec] = std::to_chars(buffer.data(), buffer.data() + buffer.size(), value);
82+
str_.append(buffer.data(), ptr - buffer.data());
83+
return *this;
84+
}
85+
86+
// Floating point types
87+
template <typename T>
88+
std::enable_if_t<std::is_floating_point_v<T>, FastOStringStream&>
89+
operator<<(T value) {
90+
std::array<char, 64> buffer;
91+
auto [ptr, ec] = std::to_chars(buffer.data(), buffer.data() + buffer.size(), value);
92+
str_.append(buffer.data(), ptr - buffer.data());
93+
return *this;
94+
}
95+
96+
private:
97+
std::string str_;
98+
};
99+
100+
} // namespace detail
101+
102+
using OStringStream = detail::FastOStringStream;
33103

34104
namespace detail {
35-
inline void OStringStreamAppendImpl(std::ostream& /*ss*/) noexcept {
105+
106+
inline void OStringStreamAppendImpl(OStringStream& /*ss*/) noexcept {
36107
}
37108

38109
template <typename T>
39-
inline void OStringStreamAppendImpl(std::ostream& ss, const T& t) noexcept {
110+
inline void OStringStreamAppendImpl(OStringStream& ss, const T& t) noexcept {
40111
ss << t;
41112
}
42113

43114
template <typename T, typename... Args>
44-
inline void OStringStreamAppendImpl(std::ostream& ss, const T& t, const Args&... args) noexcept {
115+
inline void OStringStreamAppendImpl(OStringStream& ss, const T& t, const Args&... args) noexcept {
45116
OStringStreamAppendImpl(ss, t);
46117
OStringStreamAppendImpl(ss, args...);
47118
}
48119

49120
template <typename... Args>
50-
inline void OStringStreamAppend(std::ostream& ss, const Args&... args) {
121+
inline void OStringStreamAppend(OStringStream& ss, const Args&... args) {
51122
return OStringStreamAppendImpl(ss, ::onnxruntime::detail::if_char_array_make_ptr_t<Args const&>(args)...);
52123
}
53124

0 commit comments

Comments
 (0)