Skip to content

Commit c663046

Browse files
committed
Compile modeld warps with symbolic camera sizes
1 parent 52e1826 commit c663046

5 files changed

Lines changed: 164 additions & 78 deletions

File tree

selfdrive/modeld/SConscript

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,12 @@ lenv.Command(fn + "_metadata.pkl", [fn + ".onnx"] + tinygrad_files + script_file
121121

122122
dm_w, dm_h = DM_INPUT_SIZE
123123
compile_dm_warp_script = [File(f"{modeld_dir}/compile_dm_warp.py")]
124-
for cam_w, cam_h in CAMERA_CONFIGS:
125-
dm_pkl_path = File(f"models/dm_warp_{cam_w}x{cam_h}_tinygrad.pkl").abspath
126-
cmd = (f'{tg_flags} {mac_brew_string} python3 {modeld_dir}/compile_dm_warp.py '
127-
f'--camera-resolution {cam_w}x{cam_h} --warp-to {dm_w}x{dm_h} '
128-
f'--output {dm_pkl_path}')
129-
lenv.Command(dm_pkl_path, tinygrad_files + compile_dm_warp_script + compile_modeld_script + [tg_devices_node], cmd)
124+
dm_pkl_path = File("models/dm_warp_tinygrad.pkl").abspath
125+
camera_res_args = ' '.join(f'{cw}x{ch}' for cw, ch in CAMERA_CONFIGS)
126+
cmd = (f'{tg_flags} {mac_brew_string} python3 {modeld_dir}/compile_dm_warp.py '
127+
f'--camera-resolutions {camera_res_args} --warp-to {dm_w}x{dm_h} '
128+
f'--output {dm_pkl_path}')
129+
lenv.Command(dm_pkl_path, tinygrad_files + compile_dm_warp_script + compile_modeld_script + [Value(camera_res_args), tg_devices_node], cmd)
130130

131131
def tg_compile(flags, model_name):
132132
pythonpath_string = 'PYTHONPATH="${PYTHONPATH}:' + env.Dir("#tinygrad_repo").abspath + '"'

selfdrive/modeld/compile_dm_warp.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -8,49 +8,54 @@
88
from tinygrad.engine.jit import TinyJit
99

1010
from openpilot.system.camerad.cameras.nv12_info import get_nv12_info
11-
from openpilot.selfdrive.modeld.compile_modeld import NV12Frame, warp_perspective_tinygrad, _parse_size
11+
from openpilot.selfdrive.modeld.compile_modeld import NV12Frame, bind_camera_vars, make_camera_vars, warp_perspective_tinygrad, _parse_size
1212

1313

14-
def make_warp_dm(nv12: NV12Frame, dm_w, dm_h):
15-
cam_w, cam_h, stride, _, _, _ = nv12
16-
stride_pad = stride - cam_w
17-
18-
def warp_dm(input_frame, M_inv):
14+
def make_warp_dm(dm_w, dm_h):
15+
def warp_dm(input_frame, M_inv, cam_w, cam_h, chroma_w, chroma_h, stride, uv_offset):
1916
M_inv = M_inv.to(Device.DEFAULT).realize()
20-
return warp_perspective_tinygrad(input_frame[:cam_h*stride], M_inv,
21-
(dm_w, dm_h), (cam_h, cam_w), stride_pad, border_fill_val=0).reshape(-1, dm_h * dm_w)
17+
return warp_perspective_tinygrad(input_frame, M_inv, (dm_w, dm_h),
18+
(cam_h, cam_w), stride, border_fill_val=0).reshape(-1, dm_h * dm_w)
2219
return warp_dm
2320

2421

25-
def compile_dm_warp(nv12: NV12Frame, dm_w, dm_h, pkl_path):
26-
print(f"Compiling DM warp for {nv12.width}x{nv12.height} -> {dm_w}x{dm_h}...")
22+
def compile_dm_warp(camera_configs: list[NV12Frame], dm_w, dm_h, pkl_path):
23+
print(f"Compiling DM warp for {len(camera_configs)} camera sizes -> {dm_w}x{dm_h}...")
2724

28-
warp_dm_jit = TinyJit(make_warp_dm(nv12, dm_w, dm_h), prune=True)
25+
camera_vars, max_frame_size = make_camera_vars(camera_configs)
26+
warp_dm_jit = TinyJit(make_warp_dm(dm_w, dm_h), prune=True)
2927

3028
for i in range(10):
31-
frame = Tensor.randint(nv12.size, low=0, high=256, dtype='uint8').realize()
29+
nv12 = camera_configs[i % len(camera_configs)]
30+
frame = Tensor.randint(max_frame_size, low=0, high=256, dtype='uint8').realize()
3231
M_inv = Tensor(Tensor.randn(3, 3).mul(8).realize().numpy(), device='NPY')
3332
Device.default.synchronize()
3433
st = time.perf_counter()
35-
warp_dm_jit(frame, M_inv).realize()
34+
warp_dm_jit(frame, M_inv, **bind_camera_vars(camera_vars, nv12)).realize()
3635
mt = time.perf_counter()
3736
Device.default.synchronize()
3837
et = time.perf_counter()
39-
print(f" [{i+1}/10] enqueue {(mt-st)*1e3:6.2f} ms -- total {(et-st)*1e3:6.2f} ms")
38+
print(f" [{i+1}/10] {nv12.width}x{nv12.height} enqueue {(mt-st)*1e3:6.2f} ms -- total {(et-st)*1e3:6.2f} ms")
4039

4140
with open(pkl_path, "wb") as f:
42-
pickle.dump(warp_dm_jit, f)
41+
pickle.dump({
42+
'warp': warp_dm_jit,
43+
'camera_configs': {nv12[:2]: nv12 for nv12 in camera_configs},
44+
'max_frame_size': max_frame_size,
45+
}, f)
4346
print(f" Saved to {pkl_path}")
4447

4548

4649
if __name__ == "__main__":
4750
p = argparse.ArgumentParser()
48-
p.add_argument('--camera-resolution', type=_parse_size, required=True, help='camera resolution WxH')
51+
p.add_argument('--camera-resolution', type=_parse_size, help='camera resolution WxH')
52+
p.add_argument('--camera-resolutions', type=_parse_size, nargs='+', help='camera resolutions WxH (one or more)')
4953
p.add_argument('--warp-to', type=_parse_size, required=True, help='DM input WxH')
5054
p.add_argument('--output', required=True)
5155
args = p.parse_args()
5256

53-
cam_w, cam_h = args.camera_resolution
54-
nv12 = NV12Frame(cam_w, cam_h, *get_nv12_info(cam_w, cam_h))
57+
camera_resolutions = args.camera_resolutions or ([args.camera_resolution] if args.camera_resolution else None)
58+
assert camera_resolutions is not None, "one of --camera-resolution or --camera-resolutions is required"
59+
camera_configs = [NV12Frame(cam_w, cam_h, *get_nv12_info(cam_w, cam_h)) for cam_w, cam_h in camera_resolutions]
5560
dm_w, dm_h = args.warp_to
56-
compile_dm_warp(nv12, dm_w, dm_h, args.output)
61+
compile_dm_warp(camera_configs, dm_w, dm_h, args.output)

selfdrive/modeld/compile_modeld.py

Lines changed: 100 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ def fetch_fw(path, name, sha256):
2626
_patch_tinygrad_fetch_fw()
2727

2828
from tinygrad.tensor import Tensor
29+
from tinygrad import Variable
2930
from tinygrad.helpers import Context
3031
from tinygrad.device import Device
3132
from tinygrad.engine.jit import TinyJit
3233

3334
from openpilot.common.file_chunker import read_file_chunked
34-
from openpilot.system.hardware.hw import Paths
3535

3636

3737
NV12Frame = namedtuple("NV12Frame", ['width', 'height', 'stride', 'y_height', 'uv_height', 'size'])
@@ -41,8 +41,61 @@ def fetch_fw(path, name, sha256):
4141

4242
WARP_DEV = os.getenv('WARP_DEV')
4343

44+
_ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE = None
4445

45-
def warp_perspective_tinygrad(src_flat, M_inv, dst_shape, src_shape, stride_pad, border_fill_val=None):
46+
47+
def _optimize_local_size_or_skip(call, prg):
48+
try:
49+
return _ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE(call, prg)
50+
except AssertionError as e:
51+
if str(e) != "all optimize_local_size exec failed":
52+
raise
53+
return None
54+
55+
56+
def _patch_tinygrad_local_size_optimizer():
57+
global _ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE
58+
from tinygrad.engine import realize
59+
from tinygrad.uop.ops import Ops, PatternMatcher, UPat
60+
61+
_ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE = realize.optimize_local_size
62+
realize.pm_optimize_local_size = PatternMatcher([
63+
(UPat(Ops.CALL, src=(UPat(Ops.PROGRAM, name="prg"),), name="call", allow_any_len=True), _optimize_local_size_or_skip),
64+
])
65+
66+
67+
_patch_tinygrad_local_size_optimizer()
68+
69+
70+
def make_camera_vars(camera_configs: list[NV12Frame]):
71+
max_cam_w = max(nv12.width for nv12 in camera_configs)
72+
max_cam_h = max(nv12.height for nv12 in camera_configs)
73+
max_stride = max(nv12.stride for nv12 in camera_configs)
74+
max_uv_offset = max(nv12.stride * nv12.y_height for nv12 in camera_configs)
75+
max_frame_size = max(nv12.size for nv12 in camera_configs)
76+
return {
77+
'cam_w': Variable('cam_w', 1, max_cam_w),
78+
'cam_h': Variable('cam_h', 1, max_cam_h),
79+
'chroma_w': Variable('chroma_w', 1, max_cam_w // 2),
80+
'chroma_h': Variable('chroma_h', 1, max_cam_h // 2),
81+
'stride': Variable('stride', 1, max_stride),
82+
'uv_offset': Variable('uv_offset', 1, max_uv_offset),
83+
}, max_frame_size
84+
85+
86+
def bind_camera_vars(camera_vars, nv12: NV12Frame):
87+
values = {
88+
'cam_w': nv12.width,
89+
'cam_h': nv12.height,
90+
'chroma_w': nv12.width // 2,
91+
'chroma_h': nv12.height // 2,
92+
'stride': nv12.stride,
93+
'uv_offset': nv12.stride * nv12.y_height,
94+
}
95+
return {k: v.bind(values[k]) for k, v in camera_vars.items()}
96+
97+
98+
def warp_perspective_tinygrad(src_flat, M_inv, dst_shape, src_shape, stride, src_offset=0, x_step=1, channel=0, border_fill_val=None):
4699
w_dst, h_dst = dst_shape
47100
h_src, w_src = src_shape
48101

@@ -61,7 +114,7 @@ def warp_perspective_tinygrad(src_flat, M_inv, dst_shape, src_shape, stride_pad,
61114
y_round = Tensor.round(src_y)
62115
x_nn_clipped = x_round.clip(0, w_src - 1).cast('int')
63116
y_nn_clipped = y_round.clip(0, h_src - 1).cast('int')
64-
idx = y_nn_clipped * (w_src + stride_pad) + x_nn_clipped
117+
idx = y_nn_clipped * stride + x_nn_clipped * x_step + src_offset + channel
65118
sampled = src_flat[idx]
66119

67120
if border_fill_val is None:
@@ -84,26 +137,18 @@ def frames_to_tensor(frames):
84137
return in_img1
85138

86139

87-
def make_frame_prepare(nv12: NV12Frame, model_w, model_h):
88-
cam_w, cam_h, stride, y_height, uv_height, _ = nv12
89-
uv_offset = stride * y_height
90-
stride_pad = stride - cam_w
91-
92-
def frame_prepare_tinygrad(input_frame, M_inv):
140+
def make_frame_prepare(model_w, model_h):
141+
def frame_prepare_tinygrad(input_frame, M_inv, cam_w, cam_h, chroma_w, chroma_h, stride, uv_offset):
93142
# UV_SCALE @ M_inv @ UV_SCALE_INV simplifies to elementwise scaling
94143
M_inv_uv = M_inv * Tensor([[1.0, 1.0, 0.5], [1.0, 1.0, 0.5], [2.0, 2.0, 1.0]], device=WARP_DEV)
95-
# deinterleave NV12 UV plane (UVUV... -> separate U, V)
96-
uv = input_frame[uv_offset:uv_offset + uv_height * stride].reshape(uv_height, stride)
97144
with Context(SPLIT_REDUCEOP=0):
98-
y = warp_perspective_tinygrad(input_frame[:cam_h*stride],
99-
M_inv, (model_w, model_h),
100-
(cam_h, cam_w), stride_pad).realize()
101-
u = warp_perspective_tinygrad(uv[:cam_h//2, :cam_w:2].flatten(),
102-
M_inv_uv, (model_w//2, model_h//2),
103-
(cam_h//2, cam_w//2), 0).realize()
104-
v = warp_perspective_tinygrad(uv[:cam_h//2, 1:cam_w:2].flatten(),
105-
M_inv_uv, (model_w//2, model_h//2),
106-
(cam_h//2, cam_w//2), 0).realize()
145+
y = warp_perspective_tinygrad(input_frame, M_inv, (model_w, model_h),
146+
(cam_h, cam_w), stride).realize()
147+
# Gather directly from interleaved NV12 UV memory so symbolic widths avoid step=2 slicing.
148+
u = warp_perspective_tinygrad(input_frame, M_inv_uv, (model_w//2, model_h//2),
149+
(chroma_h, chroma_w), stride, uv_offset, x_step=2, channel=0).realize()
150+
v = warp_perspective_tinygrad(input_frame, M_inv_uv, (model_w//2, model_h//2),
151+
(chroma_h, chroma_w), stride, uv_offset, x_step=2, channel=1).realize()
107152
yuv = y.cat(u).cat(v).reshape((model_h * 3 // 2, model_w))
108153
tensor = frames_to_tensor(yuv)
109154
return tensor
@@ -148,21 +193,22 @@ def sample_desire(buf, frame_skip):
148193
return buf.reshape(-1, frame_skip, *buf.shape[1:]).max(1).flatten(0, 1).unsqueeze(0)
149194

150195

151-
def make_run_policy(vision_runner, policy_runner, nv12: NV12Frame, model_w, model_h,
196+
def make_run_policy(vision_runner, policy_runner, model_w, model_h,
152197
vision_features_slice, frame_skip, prepare_only=False):
153-
frame_prepare = make_frame_prepare(nv12, model_w, model_h)
198+
frame_prepare = make_frame_prepare(model_w, model_h)
154199
sample_skip_fn = partial(sample_skip, frame_skip=frame_skip)
155200
sample_desire_fn = partial(sample_desire, frame_skip=frame_skip)
156201

157-
def run_policy(img_q, big_img_q, feat_q, desire_q, desire, traffic_convention, tfm, big_tfm, frame, big_frame):
202+
def run_policy(img_q, big_img_q, feat_q, desire_q, desire, traffic_convention, tfm, big_tfm,
203+
frame, big_frame, cam_w, cam_h, chroma_w, chroma_h, stride, uv_offset):
158204
tfm = tfm.to(WARP_DEV)
159205
big_tfm = big_tfm.to(WARP_DEV)
160206
desire = desire.to(Device.DEFAULT)
161207
traffic_convention = traffic_convention.to(Device.DEFAULT)
162208
Tensor.realize(tfm, big_tfm, desire, traffic_convention)
163209

164-
warped_frame = frame_prepare(frame, tfm).unsqueeze(0).to(Device.DEFAULT)
165-
warped_big_frame = frame_prepare(big_frame, big_tfm).unsqueeze(0).to(Device.DEFAULT)
210+
warped_frame = frame_prepare(frame, tfm, cam_w, cam_h, chroma_w, chroma_h, stride, uv_offset).unsqueeze(0).to(Device.DEFAULT)
211+
warped_big_frame = frame_prepare(big_frame, big_tfm, cam_w, cam_h, chroma_w, chroma_h, stride, uv_offset).unsqueeze(0).to(Device.DEFAULT)
166212
img = shift_and_sample(img_q, warped_frame, sample_skip_fn)
167213
big_img = shift_and_sample(big_img_q, warped_big_frame, sample_skip_fn)
168214

@@ -182,21 +228,24 @@ def run_policy(img_q, big_img_q, feat_q, desire_q, desire, traffic_convention, t
182228
return run_policy
183229

184230

185-
def compile_modeld(nv12: NV12Frame, model_w, model_h, prepare_only, frame_skip,
231+
def compile_modeld(camera_configs: list[NV12Frame], model_w, model_h, prepare_only, frame_skip,
186232
vision_runner, policy_runner, vision_metadata, policy_metadata):
187-
print(f"Compiling combined policy JIT for {nv12.width}x{nv12.height} (prepare_only={prepare_only})...")
233+
print(f"Compiling combined policy JIT for {len(camera_configs)} camera sizes (prepare_only={prepare_only})...")
188234

189235
vision_features_slice = vision_metadata['output_slices']['hidden_state']
190236
vision_input_shapes = vision_metadata['input_shapes']
191237
policy_input_shapes = policy_metadata['input_shapes']
192238

193-
_run = make_run_policy(vision_runner, policy_runner, nv12, model_w, model_h,
239+
camera_vars, max_frame_size = make_camera_vars(camera_configs)
240+
max_nv12 = max(camera_configs, key=lambda n: n.size)
241+
242+
_run = make_run_policy(vision_runner, policy_runner, model_w, model_h,
194243
vision_features_slice, frame_skip, prepare_only)
195244
run_policy_jit = TinyJit(_run, prune=True)
196245

197246
SEED = 42
198247

199-
def random_inputs_run_fn(fn, seed, test_val=None, test_buffers=None, expect_match=True):
248+
def random_inputs_run_fn(fn, seed, test_val=None, test_buffers=None, expect_match=True, camera_config=None):
200249
input_queues, npy = make_input_queues(vision_input_shapes, policy_input_shapes, frame_skip, Device.DEFAULT)
201250
np.random.seed(seed)
202251
Tensor.manual_seed(seed)
@@ -205,17 +254,19 @@ def random_inputs_run_fn(fn, seed, test_val=None, test_buffers=None, expect_matc
205254
n_runs = 1 if testing else 3
206255

207256
for i in range(n_runs):
208-
frame = Tensor.randint(nv12.size, low=0, high=256, dtype='uint8', device=WARP_DEV).realize()
209-
big_frame = Tensor.randint(nv12.size, low=0, high=256, dtype='uint8', device=WARP_DEV).realize()
257+
nv12 = camera_config or camera_configs[0]
258+
camera_args = bind_camera_vars(camera_vars, nv12)
259+
frame = Tensor.randint(max_frame_size, low=0, high=256, dtype='uint8', device=WARP_DEV).realize()
260+
big_frame = Tensor.randint(max_frame_size, low=0, high=256, dtype='uint8', device=WARP_DEV).realize()
210261
for v in npy.values():
211262
v[:] = np.random.randn(*v.shape).astype(v.dtype)
212263
Device.default.synchronize()
213264
st = time.perf_counter()
214-
outs = fn(**input_queues, frame=frame, big_frame=big_frame)
265+
outs = fn(**input_queues, frame=frame, big_frame=big_frame, **camera_args)
215266
mt = time.perf_counter()
216267
Device.default.synchronize()
217268
et = time.perf_counter()
218-
print(f" [{i+1}/{n_runs}] enqueue {(mt-st)*1e3:6.2f} ms -- total {(et-st)*1e3:6.2f} ms")
269+
print(f" [{i+1}/{n_runs}] {nv12.width}x{nv12.height} enqueue {(mt-st)*1e3:6.2f} ms -- total {(et-st)*1e3:6.2f} ms")
219270

220271
if i == 0:
221272
val = [np.copy(v.numpy()) for v in outs]
@@ -236,6 +287,11 @@ def random_inputs_run_fn(fn, seed, test_val=None, test_buffers=None, expect_matc
236287
run_policy_jit = pickle.loads(pickle.dumps(run_policy_jit))
237288
random_inputs_run_fn(run_policy_jit, SEED, test_val, test_buffers, expect_match=True)
238289
random_inputs_run_fn(run_policy_jit, SEED+1, test_val, test_buffers, expect_match=False)
290+
for i, nv12 in enumerate(camera_configs[1:]):
291+
print(f'symbolic replay {nv12.width}x{nv12.height}')
292+
random_inputs_run_fn(run_policy_jit, SEED+2+i, camera_config=nv12)
293+
run_policy_jit.max_frame_size = max_frame_size
294+
run_policy_jit.max_camera_size = (max_nv12.width, max_nv12.height)
239295
return run_policy_jit
240296

241297

@@ -245,6 +301,8 @@ def _parse_size(s):
245301

246302

247303
def read_file_chunked_to_shm(path):
304+
from openpilot.system.hardware.hw import Paths
305+
248306
shm_path = os.path.join(Paths.shm_path(), os.path.basename(path))
249307
atexit.register(lambda: os.path.exists(shm_path) and os.remove(shm_path))
250308
with open(shm_path, 'wb') as f:
@@ -274,14 +332,15 @@ def read_file_chunked_to_shm(path):
274332
out['metadata']['vision'] = make_metadata_dict(vision_path)
275333
out['metadata']['policy'] = make_metadata_dict(policy_path)
276334

277-
for cam_w, cam_h in args.camera_resolutions:
278-
nv12 = NV12Frame(cam_w, cam_h, *get_nv12_info(cam_w, cam_h))
279-
model_w, model_h = args.model_size
280-
out[(cam_w,cam_h)] = {
281-
name: compile_modeld(nv12, model_w, model_h, prepare_only, args.frame_skip,
282-
vision_runner, policy_runner, out['metadata']['vision'], out['metadata']['policy'])
283-
for name, prepare_only in [('warp_enqueue', True), ('run_policy', False)]
284-
}
335+
camera_configs = [NV12Frame(cam_w, cam_h, *get_nv12_info(cam_w, cam_h)) for cam_w, cam_h in args.camera_resolutions]
336+
model_w, model_h = args.model_size
337+
out['camera_configs'] = {nv12[:2]: nv12 for nv12 in camera_configs}
338+
out['max_frame_size'] = max(nv12.size for nv12 in camera_configs)
339+
out['symbolic'] = {
340+
name: compile_modeld(camera_configs, model_w, model_h, prepare_only, args.frame_skip,
341+
vision_runner, policy_runner, out['metadata']['vision'], out['metadata']['policy'])
342+
for name, prepare_only in [('warp_enqueue', True), ('run_policy', False)]
343+
}
285344

286345
with open(args.output, "wb") as f:
287346
pickle.dump(out, f)

selfdrive/modeld/dmonitoringmodeld.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from openpilot.system.camerad.cameras.nv12_info import get_nv12_info
1717
from openpilot.common.file_chunker import read_file_chunked
1818
from openpilot.selfdrive.modeld.parse_model_outputs import sigmoid, safe_exp
19+
from openpilot.selfdrive.modeld.compile_modeld import NV12Frame, bind_camera_vars, make_camera_vars
1920

2021
PROCESS_NAME = "selfdrive.modeld.dmonitoringmodeld"
2122
SEND_RAW_PRED = os.getenv('SEND_RAW_PRED')
@@ -43,8 +44,20 @@ def __init__(self, cam_w: int, cam_h: int):
4344
self.tensor_inputs = {k: Tensor(v, device='NPY').realize() for k,v in self.numpy_inputs.items()}
4445
self._blob_cache : dict[int, Tensor] = {}
4546
self.model_run = pickle.loads(read_file_chunked(str(MODEL_PKL_PATH)))
46-
with open(MODELS_DIR / f'dm_warp_{cam_w}x{cam_h}_tinygrad.pkl', "rb") as f:
47-
self.image_warp = pickle.load(f)
47+
self.nv12 = NV12Frame(cam_w, cam_h, *self.frame_buf_params)
48+
dm_warp_path = MODELS_DIR / 'dm_warp_tinygrad.pkl'
49+
if dm_warp_path.is_file():
50+
with open(dm_warp_path, "rb") as f:
51+
dm_warp = pickle.load(f)
52+
self.image_warp = dm_warp['warp']
53+
self.max_frame_size = dm_warp['max_frame_size']
54+
self.camera_vars, _ = make_camera_vars(list(dm_warp['camera_configs'].values()))
55+
self.camera_args = bind_camera_vars(self.camera_vars, self.nv12)
56+
else:
57+
with open(MODELS_DIR / f'dm_warp_{cam_w}x{cam_h}_tinygrad.pkl', "rb") as f:
58+
self.image_warp = pickle.load(f)
59+
self.max_frame_size = self.frame_buf_params[3]
60+
self.camera_args = {}
4861

4962
def run(self, buf: VisionBuf, calib: np.ndarray, transform: np.ndarray) -> tuple[np.ndarray, float]:
5063
self.numpy_inputs['calib'][0,:] = calib
@@ -54,10 +67,10 @@ def run(self, buf: VisionBuf, calib: np.ndarray, transform: np.ndarray) -> tuple
5467
ptr = buf.data.ctypes.data
5568
# There is a ringbuffer of imgs, just cache tensors pointing to all of them
5669
if ptr not in self._blob_cache:
57-
self._blob_cache[ptr] = Tensor.from_blob(ptr, (self.frame_buf_params[3],), dtype='uint8', device=self.DEV)
70+
self._blob_cache[ptr] = Tensor.from_blob(ptr, (self.max_frame_size,), dtype='uint8', device=self.DEV)
5871

5972
self.warp_inputs_np['transform'][:] = transform[:]
60-
self.tensor_inputs['input_img'] = self.image_warp(self._blob_cache[ptr], self.warp_inputs['transform'])
73+
self.tensor_inputs['input_img'] = self.image_warp(self._blob_cache[ptr], self.warp_inputs['transform'], **self.camera_args)
6174

6275
output = self.model_run(**self.tensor_inputs).numpy().flatten()
6376

0 commit comments

Comments
 (0)