@@ -40,36 +40,6 @@ def fetch_fw(path, name, sha256):
4040
4141WARP_DEV = os .getenv ('WARP_DEV' )
4242
# Slot for tinygrad's original `optimize_local_size`; stays None until
# _patch_tinygrad_local_size_optimizer() stores the real function here.
43- _ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE = None
44-
45-
def _optimize_local_size_or_skip(call, prg):
  """Run the saved tinygrad local-size optimizer, with a divisor-based fallback.

  Calls ``_ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE(call, prg)`` and returns its result.
  Only an ``AssertionError`` whose message is exactly
  "all optimize_local_size exec failed" triggers the fallback; every other
  exception (including other assertion failures) propagates unchanged.

  Fallback: each global dimension gets the largest divisor that does not exceed
  its cap (32, 16, 1 for the first three dimensions, 1 for any further one), and
  the global size is divided by that local size.

  Args:
    call: the matched node; must provide ``src`` and ``replace(src=...)``.
    prg: the first source of ``call``; ``prg.arg`` must be a dataclass instance
      with ``global_size`` and ``local_size`` fields.

  Returns:
    The original optimizer's result, or a copy of ``call`` whose first source
    carries the fallback ``global_size`` / ``local_size``.
  """
  try:
    return _ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE(call, prg)
  except AssertionError as e:
    # The optimizer is identified by its message; anything else is a real failure.
    if str(e) != "all optimize_local_size exec failed":
      raise
    from dataclasses import replace

    global_size = prg.arg.global_size
    preferred = (32, 16, 1)  # per-dimension cap; dimensions past the third are capped at 1

    # NOTE(review): assumes every global dimension is a positive int. For g <= 0
    # the candidate range is empty and the search below has nothing to return.
    local_size = tuple(
      next(x for x in range(min(preferred[i] if i < len(preferred) else 1, g), 0, -1) if g % x == 0)
      for i, g in enumerate(global_size)
    )
    # Each local size was chosen as an exact divisor of its global dimension, so
    # floor division is exact here; no float-division branch is needed.
    new_global = tuple(g // l for g, l in zip(global_size, local_size, strict=True))
    new_prg = prg.replace(arg=replace(prg.arg, global_size=new_global, local_size=local_size))
    return call.replace(src=(new_prg, *call.src[1:]))
58-
59-
def _patch_tinygrad_local_size_optimizer():
  """Swap tinygrad's local-size pattern matcher for the fallback-aware wrapper.

  Stores ``realize.optimize_local_size`` in ``_ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE``
  and then replaces ``realize.pm_optimize_local_size`` with a single-rule
  ``PatternMatcher`` that routes matching CALL nodes to
  ``_optimize_local_size_or_skip``.
  """
  global _ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE
  from tinygrad.engine import realize
  from tinygrad.uop.ops import Ops, PatternMatcher, UPat

  # Keep a handle on the stock optimizer so the wrapper can delegate to it.
  _ORIG_TINYGRAD_OPTIMIZE_LOCAL_SIZE = realize.optimize_local_size

  # A CALL whose first source is a PROGRAM; any number of further sources is allowed.
  program_pattern = UPat(Ops.PROGRAM, name="prg")
  call_pattern = UPat(Ops.CALL, src=(program_pattern,), name="call", allow_any_len=True)
  realize.pm_optimize_local_size = PatternMatcher([(call_pattern, _optimize_local_size_or_skip)])
69-
70-
# Module-level side effect: install the patched local-size matcher when this
# module is executed.
71- _patch_tinygrad_local_size_optimizer ()
72-
7343
7444def make_camera_vars (camera_configs : list [NV12Frame ]):
7545 max_cam_w = max (nv12 .width for nv12 in camera_configs )
0 commit comments