@@ -335,18 +335,20 @@ def _conv_transpose2d(
335335 if mem_stack :
336336 tmp_weighted_shape = (b , co , ho , wo )
337337 # The trouble here is that we allocate more than the returned size
338- out_actual_bytes = np .prod (out_shape ) * np .float32 ().itemsize
339- mem_stack .malloc (out_actual_bytes )
338+ mem_stack .malloc (np .prod (out_shape ) * np .float32 ().itemsize )
340339 mem_stack .malloc ((np .prod (tmp_weighted_shape ) + w .size ) * np .float32 ().itemsize )
341340 mem_stack .free ((np .prod (tmp_weighted_shape ) + w .size ) * np .float32 ().itemsize )
342341 if pad != 0 :
343- return [
342+ new_out_shape = [
344343 out_shape [0 ],
345344 out_shape [1 ],
346345 out_shape [2 ] - 2 * pad [0 ],
347346 out_shape [3 ] - 2 * pad [1 ],
348- ], out_actual_bytes
349- return out_shape , out_actual_bytes
347+ ]
348+ mem_stack .malloc (np .prod (new_out_shape ) * np .float32 ().itemsize )
349+ mem_stack .free (np .prod (out_shape ) * np .float32 ().itemsize )
350+ out_shape = new_out_shape
351+ return out_shape
350352
351353 out = cp .zeros (out_shape , dtype = "float32" )
352354 w = cp .asarray (w )
@@ -365,7 +367,7 @@ def _conv_transpose2d(
365367 )
366368 if pad != 0 :
367369 out = out [:, :, pad [0 ] : out .shape [2 ] - pad [0 ], pad [1 ] : out .shape [3 ] - pad [1 ]]
368- return out , None
370+ return cp . ascontiguousarray ( out )
369371
370372
371373def _afb1d (
@@ -434,17 +436,17 @@ def _sfb1d(
434436 g0 = np .concatenate ([g0 .reshape (* shape )] * C , axis = 0 )
435437 g1 = np .concatenate ([g1 .reshape (* shape )] * C , axis = 0 )
436438 pad = (L - 2 , 0 ) if d == 2 else (0 , L - 2 )
437- y_lo , y_lo_alloc_bytes = _conv_transpose2d (
439+ y_lo = _conv_transpose2d (
438440 lo , g0 , stride = s , pad = pad , groups = C , mem_stack = mem_stack
439441 )
440- y_hi , y_hi_alloc_bytes = _conv_transpose2d (
442+ y_hi = _conv_transpose2d (
441443 hi , g1 , stride = s , pad = pad , groups = C , mem_stack = mem_stack
442444 )
443445 if mem_stack :
444446 # Allocation of the sum
445447 mem_stack .malloc (np .prod (y_hi ) * np .float32 ().itemsize )
446- mem_stack .free (y_lo_alloc_bytes )
447- mem_stack .free (y_hi_alloc_bytes )
448+ mem_stack .free (np . prod ( y_lo ) * np . float32 (). itemsize )
449+ mem_stack .free (np . prod ( y_hi ) * np . float32 (). itemsize )
448450 return y_lo
449451 return y_lo + y_hi
450452
@@ -635,7 +637,10 @@ def remove_stripe_fw(
635637 # For the FFT
636638 mem_stack .malloc (2 * fcV_bytes )
637639 # This is "leaked" by the FFT
638- fcV_fft_bytes = fcV_shape [0 ] * fcV_shape [2 ] * np .complex64 ().itemsize
640+ if fcV_shape [1 ] > 150 :
641+ fcV_fft_bytes = fcV_bytes
642+ else :
643+ fcV_fft_bytes = fcV_shape [0 ] * fcV_shape [2 ] * np .complex64 ().itemsize
639644 mem_stack .malloc (fcV_fft_bytes )
640645 mem_stack .free (2 * fcV_bytes )
641646
0 commit comments