Skip to content

Commit 82a92cb

Browse files
committed
[webgpu] Support Identity
1 parent d8f0318 commit 82a92cb

File tree

3 files changed

+161
-0
lines changed

3 files changed

+161
-0
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/tensor/identity.h"
5+
#include "core/providers/webgpu/webgpu_execution_provider.h"
6+
#include "core/providers/webgpu/webgpu_supported_types.h"
7+
8+
namespace onnxruntime {
9+
namespace webgpu {
10+
11+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
12+
Identity,
13+
kOnnxDomain,
14+
1, 12,
15+
kWebGpuExecutionProvider,
16+
(*KernelDefBuilder::Create())
17+
.TypeConstraint("T", WebGpuSupportedNumberTypes())
18+
.Alias(0, 0),
19+
Identity);
20+
21+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
22+
Identity,
23+
kOnnxDomain,
24+
13, 13,
25+
kWebGpuExecutionProvider,
26+
(*KernelDefBuilder::Create())
27+
.TypeConstraint("T", WebGpuSupportedNumberTypes())
28+
.Alias(0, 0),
29+
Identity);
30+
31+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
32+
Identity,
33+
kOnnxDomain,
34+
14, 15,
35+
kWebGpuExecutionProvider,
36+
(*KernelDefBuilder::Create())
37+
.TypeConstraint("V", WebGpuSupportedNumberTypes())
38+
.Alias(0, 0),
39+
Identity);
40+
41+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
42+
Identity,
43+
kOnnxDomain,
44+
16, 18,
45+
kWebGpuExecutionProvider,
46+
(*KernelDefBuilder::Create())
47+
.TypeConstraint("V", WebGpuSupportedNumberTypes())
48+
.Alias(0, 0),
49+
Identity);
50+
51+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
52+
Identity,
53+
kOnnxDomain,
54+
19, 20,
55+
kWebGpuExecutionProvider,
56+
(*KernelDefBuilder::Create())
57+
.TypeConstraint("V", WebGpuSupportedNumberTypes())
58+
.Alias(0, 0),
59+
Identity);
60+
61+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
62+
Identity,
63+
kOnnxDomain,
64+
21, 22,
65+
kWebGpuExecutionProvider,
66+
(*KernelDefBuilder::Create())
67+
.TypeConstraint("V", WebGpuSupportedNumberTypes())
68+
.Alias(0, 0),
69+
Identity);
70+
71+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
72+
Identity,
73+
kOnnxDomain,
74+
23, 23,
75+
kWebGpuExecutionProvider,
76+
(*KernelDefBuilder::Create())
77+
.TypeConstraint("V", WebGpuSupportedNumberTypes())
78+
.Alias(0, 0),
79+
Identity);
80+
81+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
82+
Identity,
83+
kOnnxDomain,
84+
24, 24,
85+
kWebGpuExecutionProvider,
86+
(*KernelDefBuilder::Create())
87+
.TypeConstraint("V", WebGpuSupportedNumberTypes())
88+
.Alias(0, 0),
89+
Identity);
90+
91+
ONNX_OPERATOR_KERNEL_EX(
92+
Identity,
93+
kOnnxDomain,
94+
25,
95+
kWebGpuExecutionProvider,
96+
(*KernelDefBuilder::Create())
97+
.TypeConstraint("V", WebGpuSupportedNumberTypes())
98+
.Alias(0, 0),
99+
Identity);
100+
101+
} // namespace webgpu
102+
} // namespace onnxruntime
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/framework/op_kernel.h"
7+
#include "core/framework/data_transfer_manager.h"
8+
9+
namespace onnxruntime {
10+
namespace webgpu {
11+
12+
class Identity final : public OpKernel {
13+
public:
14+
explicit Identity(const OpKernelInfo& info) : OpKernel{info} {
15+
}
16+
17+
Status Compute(OpKernelContext* context) const override {
18+
const Tensor* input_tensor = context->Input<Tensor>(0);
19+
if (input_tensor == nullptr) {
20+
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
21+
}
22+
23+
const TensorShape& input_shape = input_tensor->Shape();
24+
Tensor* output_tensor = context->Output(0, input_shape);
25+
26+
const void* source = input_tensor->DataRaw();
27+
void* target = output_tensor->MutableDataRaw();
28+
29+
// If source and target pointers are not equal (non-inplace operation), we need to copy the data.
30+
if (target != source) {
31+
ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output_tensor));
32+
}
33+
34+
return Status::OK();
35+
}
36+
};
37+
38+
} // namespace webgpu
39+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,16 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxD
244244
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Reshape);
245245
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Reshape);
246246

247+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Identity);
248+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Identity);
249+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 15, Identity);
250+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, 18, Identity);
251+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Identity);
252+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Identity);
253+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Identity);
254+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, 24, Identity);
255+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 25, Identity);
256+
247257
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze);
248258
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Squeeze);
249259
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Squeeze);
@@ -535,6 +545,16 @@ std::unique_ptr<KernelRegistry> RegisterKernels(bool enable_graph_capture = fals
535545
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Reshape)>,
536546
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, Reshape)>,
537547

548+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Identity)>,
549+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 13, Identity)>,
550+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 14, 15, Identity)>,
551+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16, 18, Identity)>,
552+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 19, 20, Identity)>,
553+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 21, 22, Identity)>,
554+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 23, 23, Identity)>,
555+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 24, 24, Identity)>,
556+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 25, Identity)>,
557+
538558
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Squeeze)>,
539559
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Squeeze)>,
540560
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 20, Squeeze)>,

0 commit comments

Comments
 (0)