@@ -26,12 +26,12 @@ def fetch_fw(path, name, sha256):
2626_patch_tinygrad_fetch_fw ()
2727
2828from tinygrad .tensor import Tensor
29+ from tinygrad import Variable
2930from tinygrad .helpers import Context
3031from tinygrad .device import Device
3132from tinygrad .engine .jit import TinyJit
3233
3334from openpilot .common .file_chunker import read_file_chunked
34- from openpilot .system .hardware .hw import Paths
3535
3636
3737NV12Frame = namedtuple ("NV12Frame" , ['width' , 'height' , 'stride' , 'y_height' , 'uv_height' , 'size' ])
@@ -41,8 +41,61 @@ def fetch_fw(path, name, sha256):
4141
4242WARP_DEV = os .getenv ('WARP_DEV' )
4343
44+ _ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE = None
4445
45- def warp_perspective_tinygrad (src_flat , M_inv , dst_shape , src_shape , stride_pad , border_fill_val = None ):
46+
47+ def _optimize_local_size_or_skip (call , prg ):
48+ try :
49+ return _ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE (call , prg )
50+ except AssertionError as e :
51+ if str (e ) != "all optimize_local_size exec failed" :
52+ raise
53+ return None
54+
55+
56+ def _patch_tinygrad_local_size_optimizer ():
57+ global _ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE
58+ from tinygrad .engine import realize
59+ from tinygrad .uop .ops import Ops , PatternMatcher , UPat
60+
61+ _ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE = realize .optimize_local_size
62+ realize .pm_optimize_local_size = PatternMatcher ([
63+ (UPat (Ops .CALL , src = (UPat (Ops .PROGRAM , name = "prg" ),), name = "call" , allow_any_len = True ), _optimize_local_size_or_skip ),
64+ ])
65+
66+
67+ _patch_tinygrad_local_size_optimizer ()
68+
69+
70+ def make_camera_vars (camera_configs : list [NV12Frame ]):
71+ max_cam_w = max (nv12 .width for nv12 in camera_configs )
72+ max_cam_h = max (nv12 .height for nv12 in camera_configs )
73+ max_stride = max (nv12 .stride for nv12 in camera_configs )
74+ max_uv_offset = max (nv12 .stride * nv12 .y_height for nv12 in camera_configs )
75+ max_frame_size = max (nv12 .size for nv12 in camera_configs )
76+ return {
77+ 'cam_w' : Variable ('cam_w' , 1 , max_cam_w ),
78+ 'cam_h' : Variable ('cam_h' , 1 , max_cam_h ),
79+ 'chroma_w' : Variable ('chroma_w' , 1 , max_cam_w // 2 ),
80+ 'chroma_h' : Variable ('chroma_h' , 1 , max_cam_h // 2 ),
81+ 'stride' : Variable ('stride' , 1 , max_stride ),
82+ 'uv_offset' : Variable ('uv_offset' , 1 , max_uv_offset ),
83+ }, max_frame_size
84+
85+
86+ def bind_camera_vars (camera_vars , nv12 : NV12Frame ):
87+ values = {
88+ 'cam_w' : nv12 .width ,
89+ 'cam_h' : nv12 .height ,
90+ 'chroma_w' : nv12 .width // 2 ,
91+ 'chroma_h' : nv12 .height // 2 ,
92+ 'stride' : nv12 .stride ,
93+ 'uv_offset' : nv12 .stride * nv12 .y_height ,
94+ }
95+ return {k : v .bind (values [k ]) for k , v in camera_vars .items ()}
96+
97+
98+ def warp_perspective_tinygrad (src_flat , M_inv , dst_shape , src_shape , stride , src_offset = 0 , x_step = 1 , channel = 0 , border_fill_val = None ):
4699 w_dst , h_dst = dst_shape
47100 h_src , w_src = src_shape
48101
@@ -61,7 +114,7 @@ def warp_perspective_tinygrad(src_flat, M_inv, dst_shape, src_shape, stride_pad,
61114 y_round = Tensor .round (src_y )
62115 x_nn_clipped = x_round .clip (0 , w_src - 1 ).cast ('int' )
63116 y_nn_clipped = y_round .clip (0 , h_src - 1 ).cast ('int' )
64- idx = y_nn_clipped * ( w_src + stride_pad ) + x_nn_clipped
117+ idx = y_nn_clipped * stride + x_nn_clipped * x_step + src_offset + channel
65118 sampled = src_flat [idx ]
66119
67120 if border_fill_val is None :
@@ -84,26 +137,18 @@ def frames_to_tensor(frames):
84137 return in_img1
85138
86139
87- def make_frame_prepare (nv12 : NV12Frame , model_w , model_h ):
88- cam_w , cam_h , stride , y_height , uv_height , _ = nv12
89- uv_offset = stride * y_height
90- stride_pad = stride - cam_w
91-
92- def frame_prepare_tinygrad (input_frame , M_inv ):
140+ def make_frame_prepare (model_w , model_h ):
141+ def frame_prepare_tinygrad (input_frame , M_inv , cam_w , cam_h , chroma_w , chroma_h , stride , uv_offset ):
93142 # UV_SCALE @ M_inv @ UV_SCALE_INV simplifies to elementwise scaling
94143 M_inv_uv = M_inv * Tensor ([[1.0 , 1.0 , 0.5 ], [1.0 , 1.0 , 0.5 ], [2.0 , 2.0 , 1.0 ]], device = WARP_DEV )
95- # deinterleave NV12 UV plane (UVUV... -> separate U, V)
96- uv = input_frame [uv_offset :uv_offset + uv_height * stride ].reshape (uv_height , stride )
97144 with Context (SPLIT_REDUCEOP = 0 ):
98- y = warp_perspective_tinygrad (input_frame [:cam_h * stride ],
99- M_inv , (model_w , model_h ),
100- (cam_h , cam_w ), stride_pad ).realize ()
101- u = warp_perspective_tinygrad (uv [:cam_h // 2 , :cam_w :2 ].flatten (),
102- M_inv_uv , (model_w // 2 , model_h // 2 ),
103- (cam_h // 2 , cam_w // 2 ), 0 ).realize ()
104- v = warp_perspective_tinygrad (uv [:cam_h // 2 , 1 :cam_w :2 ].flatten (),
105- M_inv_uv , (model_w // 2 , model_h // 2 ),
106- (cam_h // 2 , cam_w // 2 ), 0 ).realize ()
145+ y = warp_perspective_tinygrad (input_frame , M_inv , (model_w , model_h ),
146+ (cam_h , cam_w ), stride ).realize ()
147+ # Gather directly from interleaved NV12 UV memory so symbolic widths avoid step=2 slicing.
148+ u = warp_perspective_tinygrad (input_frame , M_inv_uv , (model_w // 2 , model_h // 2 ),
149+ (chroma_h , chroma_w ), stride , uv_offset , x_step = 2 , channel = 0 ).realize ()
150+ v = warp_perspective_tinygrad (input_frame , M_inv_uv , (model_w // 2 , model_h // 2 ),
151+ (chroma_h , chroma_w ), stride , uv_offset , x_step = 2 , channel = 1 ).realize ()
107152 yuv = y .cat (u ).cat (v ).reshape ((model_h * 3 // 2 , model_w ))
108153 tensor = frames_to_tensor (yuv )
109154 return tensor
@@ -148,21 +193,22 @@ def sample_desire(buf, frame_skip):
148193 return buf .reshape (- 1 , frame_skip , * buf .shape [1 :]).max (1 ).flatten (0 , 1 ).unsqueeze (0 )
149194
150195
151- def make_run_policy (vision_runner , policy_runner , nv12 : NV12Frame , model_w , model_h ,
196+ def make_run_policy (vision_runner , policy_runner , model_w , model_h ,
152197 vision_features_slice , frame_skip , prepare_only = False ):
153- frame_prepare = make_frame_prepare (nv12 , model_w , model_h )
198+ frame_prepare = make_frame_prepare (model_w , model_h )
154199 sample_skip_fn = partial (sample_skip , frame_skip = frame_skip )
155200 sample_desire_fn = partial (sample_desire , frame_skip = frame_skip )
156201
157- def run_policy (img_q , big_img_q , feat_q , desire_q , desire , traffic_convention , tfm , big_tfm , frame , big_frame ):
202+ def run_policy (img_q , big_img_q , feat_q , desire_q , desire , traffic_convention , tfm , big_tfm ,
203+ frame , big_frame , cam_w , cam_h , chroma_w , chroma_h , stride , uv_offset ):
158204 tfm = tfm .to (WARP_DEV )
159205 big_tfm = big_tfm .to (WARP_DEV )
160206 desire = desire .to (Device .DEFAULT )
161207 traffic_convention = traffic_convention .to (Device .DEFAULT )
162208 Tensor .realize (tfm , big_tfm , desire , traffic_convention )
163209
164- warped_frame = frame_prepare (frame , tfm ).unsqueeze (0 ).to (Device .DEFAULT )
165- warped_big_frame = frame_prepare (big_frame , big_tfm ).unsqueeze (0 ).to (Device .DEFAULT )
210+ warped_frame = frame_prepare (frame , tfm , cam_w , cam_h , chroma_w , chroma_h , stride , uv_offset ).unsqueeze (0 ).to (Device .DEFAULT )
211+ warped_big_frame = frame_prepare (big_frame , big_tfm , cam_w , cam_h , chroma_w , chroma_h , stride , uv_offset ).unsqueeze (0 ).to (Device .DEFAULT )
166212 img = shift_and_sample (img_q , warped_frame , sample_skip_fn )
167213 big_img = shift_and_sample (big_img_q , warped_big_frame , sample_skip_fn )
168214
@@ -182,21 +228,24 @@ def run_policy(img_q, big_img_q, feat_q, desire_q, desire, traffic_convention, t
182228 return run_policy
183229
184230
185- def compile_modeld (nv12 : NV12Frame , model_w , model_h , prepare_only , frame_skip ,
231+ def compile_modeld (camera_configs : list [ NV12Frame ] , model_w , model_h , prepare_only , frame_skip ,
186232 vision_runner , policy_runner , vision_metadata , policy_metadata ):
187- print (f"Compiling combined policy JIT for { nv12 . width } x { nv12 . height } (prepare_only={ prepare_only } )..." )
233+ print (f"Compiling combined policy JIT for { len ( camera_configs ) } camera sizes (prepare_only={ prepare_only } )..." )
188234
189235 vision_features_slice = vision_metadata ['output_slices' ]['hidden_state' ]
190236 vision_input_shapes = vision_metadata ['input_shapes' ]
191237 policy_input_shapes = policy_metadata ['input_shapes' ]
192238
193- _run = make_run_policy (vision_runner , policy_runner , nv12 , model_w , model_h ,
239+ camera_vars , max_frame_size = make_camera_vars (camera_configs )
240+ max_nv12 = max (camera_configs , key = lambda n : n .size )
241+
242+ _run = make_run_policy (vision_runner , policy_runner , model_w , model_h ,
194243 vision_features_slice , frame_skip , prepare_only )
195244 run_policy_jit = TinyJit (_run , prune = True )
196245
197246 SEED = 42
198247
199- def random_inputs_run_fn (fn , seed , test_val = None , test_buffers = None , expect_match = True ):
248+ def random_inputs_run_fn (fn , seed , test_val = None , test_buffers = None , expect_match = True , camera_config = None ):
200249 input_queues , npy = make_input_queues (vision_input_shapes , policy_input_shapes , frame_skip , Device .DEFAULT )
201250 np .random .seed (seed )
202251 Tensor .manual_seed (seed )
@@ -205,17 +254,19 @@ def random_inputs_run_fn(fn, seed, test_val=None, test_buffers=None, expect_matc
205254 n_runs = 1 if testing else 3
206255
207256 for i in range (n_runs ):
208- frame = Tensor .randint (nv12 .size , low = 0 , high = 256 , dtype = 'uint8' , device = WARP_DEV ).realize ()
209- big_frame = Tensor .randint (nv12 .size , low = 0 , high = 256 , dtype = 'uint8' , device = WARP_DEV ).realize ()
257+ nv12 = camera_config or camera_configs [0 ]
258+ camera_args = bind_camera_vars (camera_vars , nv12 )
259+ frame = Tensor .randint (max_frame_size , low = 0 , high = 256 , dtype = 'uint8' , device = WARP_DEV ).realize ()
260+ big_frame = Tensor .randint (max_frame_size , low = 0 , high = 256 , dtype = 'uint8' , device = WARP_DEV ).realize ()
210261 for v in npy .values ():
211262 v [:] = np .random .randn (* v .shape ).astype (v .dtype )
212263 Device .default .synchronize ()
213264 st = time .perf_counter ()
214- outs = fn (** input_queues , frame = frame , big_frame = big_frame )
265+ outs = fn (** input_queues , frame = frame , big_frame = big_frame , ** camera_args )
215266 mt = time .perf_counter ()
216267 Device .default .synchronize ()
217268 et = time .perf_counter ()
218- print (f" [{ i + 1 } /{ n_runs } ] enqueue { (mt - st )* 1e3 :6.2f} ms -- total { (et - st )* 1e3 :6.2f} ms" )
269+ print (f" [{ i + 1 } /{ n_runs } ] { nv12 . width } x { nv12 . height } enqueue { (mt - st )* 1e3 :6.2f} ms -- total { (et - st )* 1e3 :6.2f} ms" )
219270
220271 if i == 0 :
221272 val = [np .copy (v .numpy ()) for v in outs ]
@@ -236,6 +287,11 @@ def random_inputs_run_fn(fn, seed, test_val=None, test_buffers=None, expect_matc
236287 run_policy_jit = pickle .loads (pickle .dumps (run_policy_jit ))
237288 random_inputs_run_fn (run_policy_jit , SEED , test_val , test_buffers , expect_match = True )
238289 random_inputs_run_fn (run_policy_jit , SEED + 1 , test_val , test_buffers , expect_match = False )
290+ for i , nv12 in enumerate (camera_configs [1 :]):
291+ print (f'symbolic replay { nv12 .width } x{ nv12 .height } ' )
292+ random_inputs_run_fn (run_policy_jit , SEED + 2 + i , camera_config = nv12 )
293+ run_policy_jit .max_frame_size = max_frame_size
294+ run_policy_jit .max_camera_size = (max_nv12 .width , max_nv12 .height )
239295 return run_policy_jit
240296
241297
@@ -245,6 +301,8 @@ def _parse_size(s):
245301
246302
247303def read_file_chunked_to_shm (path ):
304+ from openpilot .system .hardware .hw import Paths
305+
248306 shm_path = os .path .join (Paths .shm_path (), os .path .basename (path ))
249307 atexit .register (lambda : os .path .exists (shm_path ) and os .remove (shm_path ))
250308 with open (shm_path , 'wb' ) as f :
@@ -274,14 +332,15 @@ def read_file_chunked_to_shm(path):
274332 out ['metadata' ]['vision' ] = make_metadata_dict (vision_path )
275333 out ['metadata' ]['policy' ] = make_metadata_dict (policy_path )
276334
277- for cam_w , cam_h in args .camera_resolutions :
278- nv12 = NV12Frame (cam_w , cam_h , * get_nv12_info (cam_w , cam_h ))
279- model_w , model_h = args .model_size
280- out [(cam_w ,cam_h )] = {
281- name : compile_modeld (nv12 , model_w , model_h , prepare_only , args .frame_skip ,
282- vision_runner , policy_runner , out ['metadata' ]['vision' ], out ['metadata' ]['policy' ])
283- for name , prepare_only in [('warp_enqueue' , True ), ('run_policy' , False )]
284- }
335+ camera_configs = [NV12Frame (cam_w , cam_h , * get_nv12_info (cam_w , cam_h )) for cam_w , cam_h in args .camera_resolutions ]
336+ model_w , model_h = args .model_size
337+ out ['camera_configs' ] = {nv12 [:2 ]: nv12 for nv12 in camera_configs }
338+ out ['max_frame_size' ] = max (nv12 .size for nv12 in camera_configs )
339+ out ['symbolic' ] = {
340+ name : compile_modeld (camera_configs , model_w , model_h , prepare_only , args .frame_skip ,
341+ vision_runner , policy_runner , out ['metadata' ]['vision' ], out ['metadata' ]['policy' ])
342+ for name , prepare_only in [('warp_enqueue' , True ), ('run_policy' , False )]
343+ }
285344
286345 with open (args .output , "wb" ) as f :
287346 pickle .dump (out , f )
0 commit comments