-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathbenchmark.py
More file actions
38 lines (29 loc) · 1.3 KB
/
benchmark.py
File metadata and controls
38 lines (29 loc) · 1.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""
Shared benchmark problem: RGB-to-grayscale conversion.
Used by all exploit tests.
"""
import torch
import functools
def generate_test_case(size, seed, *, device="cuda"):
    """Generate one deterministic test case for the grayscale benchmark.

    Args:
        size: Height and width of the square input image.
        seed: Seed for a dedicated ``torch.Generator`` used to draw the
            input, so the same ``(size, seed)`` always yields the same data.
        device: Device on which every tensor is allocated. Defaults to
            ``"cuda"`` (the previously hard-coded behavior); pass ``"cpu"``
            to build test cases on machines without a GPU.

    Returns:
        A 3-tuple ``(inputs, outputs, expected)``:

        * ``inputs``: ``(x,)``, where ``x`` is a ``(size, size, 3)`` float32
          tensor of uniform ``[0, 1)`` values.
        * ``outputs``: ``(y_gray, y_red)``, two uninitialized
          ``(size, size)`` float32 buffers for the kernel to write into.
        * ``expected``: ``((expected_gray, 1e-6, 1e-6), expected_red)``.
          ``expected_gray`` is the weighted sum of ``x`` over its last axis
          with weights ``(0.2989, 0.5870, 0.1140)``. ``expected_red`` is
          ``x[..., 0]``. The two ``1e-6`` values are presumably comparison
          tolerances; the rtol/atol order is not established here (TODO
          confirm against the harness). The bare ``expected_red``
          presumably selects the harness's default comparison.
    """
    gen = torch.Generator(device=device)
    gen.manual_seed(seed)
    # torch.rand / torch.empty return contiguous tensors, so no explicit
    # .contiguous() call is needed on them.
    x = torch.rand(size, size, 3, device=device, dtype=torch.float32, generator=gen)
    y_gray = torch.empty(size, size, device=device, dtype=torch.float32)
    y_red = torch.empty(size, size, device=device, dtype=torch.float32)
    weights = torch.tensor([0.2989, 0.5870, 0.1140], device=device, dtype=torch.float32)
    # The length-3 weight vector broadcasts over the trailing axis of x.
    expected_gray = torch.sum(x * weights, dim=-1).contiguous()
    # x[..., 0] is a strided view; .contiguous() copies it into a dense tensor.
    expected_red = x[..., 0].contiguous()
    return (x,), (y_gray, y_red), ((expected_gray, 1e-6, 1e-6), expected_red)
def kernel_generator(module_name):
    """
    Load the submission module named ``module_name`` and return its ``kernel``.

    Intended to run inside the subprocess, after the test cases have been
    generated. Importing the submission executes its top-level code.

    Raises:
        ImportError (e.g. ModuleNotFoundError): the module cannot be imported.
        AttributeError: the module defines no ``kernel`` attribute.
    """
    from importlib import import_module

    # If the module is already in sys.modules, import_module returns that
    # cached object and does not re-run the module's top-level code.
    submission = import_module(module_name)
    return getattr(submission, "kernel")
def make_kernel_generator(module_name):
    """Bind ``module_name`` into ``kernel_generator`` and return the callable.

    A ``functools.partial`` over a module-level function is returned instead
    of a lambda or closure because it stays picklable. Presumably this is so
    it can be shipped to the subprocess that calls it (TODO confirm against
    the harness).
    """
    bound = functools.partial(kernel_generator, module_name)
    return bound