
Commit f642579

Add sample test for pre-allocated outputs
1 parent e23dabb commit f642579

File tree

2 files changed: +193 −0 lines


models/split.onnx

202 Bytes
Binary file not shown.

test-pre-allocation.html

Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="UTF-8" />
  <meta http-equiv="X-UA-Compatible" content="IE=edge" />
  <meta name="viewport" content="width=device-width, initial-scale=1.0" />
  <title>ONNXRuntime Web Test - Pre-allocate Output Tensor</title>
  <style>
    body {
      font-family: sans-serif;
      padding: 20px;
    }

    h1 {
      color: #425066;
      font-size: 31px;
      margin-top: 0;
    }

    .loading-stats {
      color: #aaa;
      font-size: 12px;
      margin-top: -12px;
    }

    .hide {
      display: none;
    }

    .content {
      margin-top: 30px;
    }

    div {
      margin-top: 20px;
    }
  </style>
</head>

<body>
  <h1>ONNXRuntime Web Test - Pre-allocate Output Tensor</h1>

  <!-- Loading status -->
  <div class="loading-stats">Choose options then click 'Run'...</div>
  <div>
    Pre-allocate Output Type:
    <select id="preAllocateType">
      <option value="gpu-one">Pre-allocate One GPU Tensor</option>
      <option value="gpu-all">Pre-allocate All GPU Tensors</option>
      <option value="cpu">Pre-allocate All CPU Tensors</option>
    </select>
  </div>
  <div>
    <input type="button" value="Run" id="run" />
  </div>
  <div id="status" style="font: 1em sans-serif"></div>
  <script src="https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.0-dev.20250810-5d77b73c4e/dist/ort.webgpu.min.js"
    integrity="sha256-5yqgD+GVsK1MN4MBoXoYfzfrrz4C/FBG6p+tJaCItJs=" crossorigin="anonymous"></script>
  <script>
    const log = (i) => {
      console.log(i);
      document.getElementById('status').innerText +=
        `\n[${performance.now().toFixed(3)}] ` + i;
    };

    ort.env.wasm.numThreads = 4;
    ort.env.wasm.simd = true;
    ort.env.wasm.proxy = false;
    ort.env.logLevel = 'error';

    const calcNormalizedBufferSize = size => Math.ceil(Number(size) / 16) * 16;
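    // WebGPU requires buffer copy sizes to be 4-byte aligned; rounding the size up to a
    // multiple of 16 bytes here presumably mirrors ORT Web's own GPU buffer size
    // normalization, so copies between the two buffers stay valid.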

    const downloadGpuData = async (device, gpuBuffer, originalSize) => {
      const bufferSize = calcNormalizedBufferSize(originalSize);
      const gpuReadBuffer = device.createBuffer({
        size: bufferSize,
        usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
      });
      try {
        const commandEncoder = device.createCommandEncoder();

        commandEncoder.copyBufferToBuffer(
          gpuBuffer /* source buffer */,
          0 /* source offset */,
          gpuReadBuffer /* destination buffer */,
          0 /* destination offset */,
          bufferSize /* size */,
        );
        device.queue.submit([commandEncoder.finish()]);
        await gpuReadBuffer.mapAsync(GPUMapMode.READ);

        const arrayBuffer = gpuReadBuffer.getMappedRange();

        // the mapped ArrayBuffer will be released when the GPU buffer is destroyed. Need to clone the
        // ArrayBuffer.
        return new Uint8Array(arrayBuffer.slice(0, originalSize));
      } catch (e) {
        log(e);
      } finally {
        gpuReadBuffer.destroy();
      }
    };

    const createGpuTensor = (device, dataType, dims, bufferSize) => {
      const gpuBuffer = device.createBuffer({
        usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE,
        size: calcNormalizedBufferSize(bufferSize),
      });
      return ort.Tensor.fromGpuBuffer(gpuBuffer, { dataType, dims });
    };

    async function run() {
      const preAllocateType = document.getElementById('preAllocateType').value;
      const modelPath = 'models/split.onnx';
      log('entering run ...');
      try {
        const options = {
          executionProviders: [{ name: preAllocateType.startsWith('gpu') ? 'webgpu' : 'wasm' }],
        };

        log('creating session ...');
        console.log('sessionOptions: ', options);
        const sess = await ort.InferenceSession.create(modelPath, options);

        // - Input:
        //   - name: input, tensor: float32[2,6]
        // - Output:
        //   - name: output_1, tensor: float32[2,3]
        //   - name: output_2, tensor: float32[2,3]
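        // Since the [2,6] input is split into two [2,3] outputs, the split runs along the
        // last axis. With the 1..12 row-major input built below, output_1 is expected to be
        // 1, 2, 3, 7, 8, 9 and output_2 to be 4, 5, 6, 10, 11, 12 (assuming standard ONNX
        // Split semantics for this model).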

        const feed = {};
        const fetches = {};

        const inputBuffer = new Float32Array(Array.from({ length: 12 }, (_, i) => i + 1));
        feed['input'] = new ort.Tensor('float32', inputBuffer, [2, 6]);

        let device;
        const outputBufferSize = 2 * 3 * 4; // 4 bytes per float
        if (preAllocateType.startsWith('gpu')) {
          device = ort.env.webgpu.device;
          fetches['output_1'] = createGpuTensor(device, 'float32', [2, 3], outputBufferSize);

          if (preAllocateType === 'gpu-all') {
            fetches['output_2'] = createGpuTensor(device, 'float32', [2, 3], outputBufferSize);
          }
        } else {
          fetches['output_1'] = new ort.Tensor('float32', new Float32Array(2 * 3), [2, 3]);
          fetches['output_2'] = new ort.Tensor('float32', new Float32Array(2 * 3), [2, 3]);
        }

        log('running ...');
        console.log('inputs: ', feed);

        const outputs = await sess.run(feed, fetches);
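        // Passing the fetches object as the second argument asks ORT to write the results
        // into the tensors pre-allocated above; the entries returned in `outputs` should
        // then be backed by those same buffers. Any output not supplied in fetches (here
        // output_2 in the 'gpu-one' case) is allocated by ORT itself.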

        console.log('outputs: ', outputs);

        let output1Data = [], output2Data = [];
        if (preAllocateType.startsWith('gpu')) {
          const output1DataBuffer = await downloadGpuData(device, outputs['output_1'].gpuBufferData, outputBufferSize);
          output1Data = Array.from(new Float32Array(output1DataBuffer.buffer));
          if (preAllocateType === 'gpu-all') {
            const output2DataBuffer = await downloadGpuData(device, outputs['output_2'].gpuBufferData, outputBufferSize);
            output2Data = Array.from(new Float32Array(output2DataBuffer.buffer));
          } else {
            if (outputs['output_2'] === undefined) {
              log('output_2 is not defined in the outputs.');
            } else {
              output2Data = Array.from(outputs['output_2'].cpuData);
            }
          }
        } else {
          output1Data = Array.from(outputs['output_1'].cpuData);
          output2Data = Array.from(outputs['output_2'].cpuData);
        }
        log(`output_1: ${output1Data.join(', ')}`);
        log(`output_2: ${output2Data.join(', ')}`);

      } catch (e) {
        log(e);
      }
    }

    const runBtn = document.getElementById('run');
    runBtn.onclick = async () => {
      await run();
    };
  </script>
</body>

</html>
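
To try the page locally, it has to be served over HTTP: the browser fetches models/split.onnx relative to the page, and that fetch generally fails from file://. Any static file server works; the sketch below uses only Node.js built-in modules, and the serve.js name, port 8080 and MIME map are illustrative, not part of this commit.

// serve.js - minimal static server sketch for the test page (illustrative, not part of this commit).
const http = require('http');
const fs = require('fs');
const path = require('path');

// Only the extensions this test needs; everything else falls back to octet-stream.
const types = { '.html': 'text/html', '.onnx': 'application/octet-stream' };

http.createServer((req, res) => {
  // Resolve the requested path relative to this script's directory; default to the test page.
  const urlPath = decodeURIComponent(new URL(req.url, 'http://localhost').pathname);
  const file = path.join(__dirname, urlPath === '/' ? 'test-pre-allocation.html' : urlPath);
  fs.readFile(file, (err, data) => {
    if (err) { res.writeHead(404); res.end('Not found'); return; }
    res.writeHead(200, { 'Content-Type': types[path.extname(file)] || 'application/octet-stream' });
    res.end(data);
  });
}).listen(8080, () => console.log('Serving on http://localhost:8080'));

Then open http://localhost:8080/test-pre-allocation.html in a WebGPU-capable browser (for the GPU options) and click 'Run'.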
