1+ import dask
2+ import numpy as np
3+ import cupy as cp
4+ import tvm
5+ import jax
6+ import torch
7+ from scalable_integration .utils import get_chunks
8+ from scalable_integration .custom_worker import get_operator
# Enable 64-bit precision in JAX so its dtypes match the other runtimes.
jax.config.update("jax_enable_x64", True)
10+
11+
def complex_trace_base(
    rt,
    input_data,
    output_data,
    chunksize,
    overlap,
    dtype,
    device,
):
    """Partition ``input_data`` into overlapping chunks and build one
    delayed task per chunk for the runtime selected by ``rt``.

    Returns the list of delayed tasks together with the chunk size used.
    Raises ``KeyError`` when ``rt`` names an unknown runtime.
    """
    in_ind, out_ind, padding = get_chunks(
        data_shape=input_data.shape,
        chunksize=chunksize,
        overlap=overlap,
    )

    # Dispatch table: runtime label -> delayed chunk worker.
    dispatch = {
        "tvm": complex_trace_tvm,
        "baseline": complex_trace_baseline,
        "torch_c": complex_trace_torch_c,
        "torch_n": complex_trace_torch_n,
        "jax": complex_trace_jax,
    }
    worker = dispatch[rt]

    tasks = []
    for chunk_start, out_start, pad in zip(in_ind, out_ind, padding):
        tasks.append(
            worker(
                input_data=input_data,
                output_data=output_data,
                indx=chunk_start,
                out_indx=out_start,
                chunksize=chunksize,
                pad_width=pad,
                overlap=overlap,
                dtype=dtype,
                device=device,
            )
        )

    return tasks, chunksize
54+
55+
56+
@dask.delayed
def complex_trace_tvm(
    input_data,
    output_data,
    indx,
    out_indx,
    chunksize,
    pad_width,
    overlap,
    dtype,
    device
):
    """Transform one padded chunk through the TVM operator and write the
    trimmed result into ``output_data`` in place."""
    # Window to read: chunk + halo on each axis, shrunk by the amount that
    # will be re-added below as zero padding.
    read_window = tuple(
        slice(start, start + size + halo[0] + halo[1] - pad[0] - pad[1])
        for start, size, pad, halo in zip(indx, chunksize, pad_width, overlap)
    )
    chunk = np.pad(
        input_data[read_window].astype(dtype),
        pad_width=pad_width,
        mode="constant",
        constant_values=0,
    )

    operator = get_operator()

    # Round-trip through TVM NDArrays on the operator's device.
    src = tvm.nd.array(chunk, device=operator._dev)
    dst = tvm.nd.empty(src.shape, dtype=src.dtype, device=operator._dev)
    operator.transform(src, dst)
    result = dst.numpy()

    # Clamp the write extent so a final partial chunk stays inside bounds.
    extents = [
        min(size, limit - start)
        for start, size, limit in zip(out_indx, chunksize, output_data.shape)
    ]
    dest = tuple(slice(s, s + e) for s, e in zip(out_indx, extents))
    # Drop the leading halo from the result before writing it back.
    crop = tuple(slice(h[0], h[0] + e) for e, h in zip(extents, overlap))

    output_data[dest] = result[crop]
102+
@dask.delayed
def complex_trace_baseline(
    input_data,
    output_data,
    indx,
    out_indx,
    chunksize,
    pad_width,
    overlap,
    dtype,
    device
):
    """Transform one padded chunk with the baseline (NumPy/CuPy) operator
    and write the trimmed result into ``output_data`` in place."""
    # Window to read: chunk + halo on each axis, shrunk by the amount that
    # will be re-added below as zero padding.
    read_window = tuple(
        slice(start, start + size + halo[0] + halo[1] - pad[0] - pad[1])
        for start, size, pad, halo in zip(indx, chunksize, pad_width, overlap)
    )
    chunk = np.pad(
        input_data[read_window].astype(dtype),
        pad_width=pad_width,
        mode="constant",
        constant_values=0,
    )

    operator = get_operator()
    if device == "cpu":
        result = operator._transform_cpu(chunk)
    else:
        # Move the chunk to the GPU for the transform, then pull the
        # result back to host memory.
        result = operator._transform_gpu(cp.asarray(chunk)).get()

    # Clamp the write extent so a final partial chunk stays inside bounds.
    extents = [
        min(size, limit - start)
        for start, size, limit in zip(out_indx, chunksize, output_data.shape)
    ]
    dest = tuple(slice(s, s + e) for s, e in zip(out_indx, extents))
    # Drop the leading halo from the result before writing it back.
    crop = tuple(slice(h[0], h[0] + e) for e, h in zip(extents, overlap))

    output_data[dest] = result[crop]
147+
@dask.delayed
def complex_trace_jax(
    input_data,
    output_data,
    indx,
    out_indx,
    chunksize,
    pad_width,
    overlap,
    dtype,
    device
):
    """Transform one padded chunk with the JAX operator and write the
    trimmed result into ``output_data`` in place."""
    # Window to read: chunk + halo on each axis, shrunk by the amount that
    # will be re-added below as zero padding.
    read_window = tuple(
        slice(start, start + size + halo[0] + halo[1] - pad[0] - pad[1])
        for start, size, pad, halo in zip(indx, chunksize, pad_width, overlap)
    )
    chunk = np.pad(
        input_data[read_window].astype(dtype),
        pad_width=pad_width,
        mode="constant",
        constant_values=0,
    )
    # Place the chunk on the first device of the requested backend.
    chunk = jax.device_put(chunk, device=jax.devices(device)[0])

    operator = get_operator()
    if device == "cpu":
        transformed = operator._transform_cpu(chunk)
    else:
        transformed = operator._transform_gpu(chunk)
    result = np.asarray(transformed)

    # Clamp the write extent so a final partial chunk stays inside bounds.
    extents = [
        min(size, limit - start)
        for start, size, limit in zip(out_indx, chunksize, output_data.shape)
    ]
    dest = tuple(slice(s, s + e) for s, e in zip(out_indx, extents))
    # Drop the leading halo from the result before writing it back.
    crop = tuple(slice(h[0], h[0] + e) for e, h in zip(extents, overlap))

    output_data[dest] = result[crop]
193+
@dask.delayed
def complex_trace_torch_c(
    input_data,
    output_data,
    indx,
    out_indx,
    chunksize,
    pad_width,
    overlap,
    dtype,
    device
):
    """Transform one padded chunk with the compiled torch operator and
    write the trimmed result into ``output_data`` in place."""
    # Window to read: chunk + halo on each axis, shrunk by the amount that
    # will be re-added below as zero padding.
    read_window = tuple(
        slice(start, start + size + halo[0] + halo[1] - pad[0] - pad[1])
        for start, size, pad, halo in zip(indx, chunksize, pad_width, overlap)
    )
    chunk = np.pad(
        input_data[read_window].astype(dtype),
        pad_width=pad_width,
        mode="constant",
        constant_values=0,
    )
    # Map any non-"cpu" device string to CUDA.
    target = torch.device("cpu" if device == "cpu" else "cuda")
    chunk = torch.from_numpy(chunk).to(target)

    operator = get_operator()
    if device == "cpu":
        transformed = operator._transform_cpu(chunk)
    else:
        transformed = operator._transform_gpu(chunk).cpu()
    result = transformed.numpy()

    # Clamp the write extent so a final partial chunk stays inside bounds.
    extents = [
        min(size, limit - start)
        for start, size, limit in zip(out_indx, chunksize, output_data.shape)
    ]
    dest = tuple(slice(s, s + e) for s, e in zip(out_indx, extents))
    # Drop the leading halo from the result before writing it back.
    crop = tuple(slice(h[0], h[0] + e) for e, h in zip(extents, overlap))

    output_data[dest] = result[crop]
238+
239+
@dask.delayed
def complex_trace_torch_n(
    input_data,
    output_data,
    indx,
    out_indx,
    chunksize,
    pad_width,
    overlap,
    dtype,
    device
):
    """Transform one padded chunk with the non-compiled torch operator and
    write the trimmed result into ``output_data`` in place."""
    # Window to read: chunk + halo on each axis, shrunk by the amount that
    # will be re-added below as zero padding.
    read_window = tuple(
        slice(start, start + size + halo[0] + halo[1] - pad[0] - pad[1])
        for start, size, pad, halo in zip(indx, chunksize, pad_width, overlap)
    )
    chunk = np.pad(
        input_data[read_window].astype(dtype),
        pad_width=pad_width,
        mode="constant",
        constant_values=0,
    )
    # Map any non-"cpu" device string to CUDA.
    target = torch.device("cpu" if device == "cpu" else "cuda")
    chunk = torch.from_numpy(chunk).to(target)

    operator = get_operator()
    if device == "cpu":
        transformed = operator._nocompile_cpu(chunk)
    else:
        transformed = operator._nocompile_gpu(chunk).cpu()
    result = transformed.numpy()

    # Clamp the write extent so a final partial chunk stays inside bounds.
    extents = [
        min(size, limit - start)
        for start, size, limit in zip(out_indx, chunksize, output_data.shape)
    ]
    dest = tuple(slice(s, s + e) for s, e in zip(out_indx, extents))
    # Drop the leading halo from the result before writing it back.
    crop = tuple(slice(h[0], h[0] + e) for e, h in zip(extents, overlap))

    output_data[dest] = result[crop]