|
1 | 1 | import jax
|
2 | 2 | import jax.numpy as jnp
|
3 | 3 |
|
| 4 | + |
def cubic_kernel(x, a=-0.75):
    """Evaluate the Keys cubic convolution kernel element-wise.

    Args:
        x: Array of signed distances between a sample position and a tap.
        a: Kernel coefficient. The default of -0.75 is the PyTorch-like
            value; other values give other members of the Keys family.

    Returns:
        An array shaped like ``x`` holding the kernel weight: the inner
        polynomial where ``|x| <= 1``, the outer polynomial where
        ``1 < |x| < 2`` and ``0`` everywhere else.
    """
    absx = jnp.abs(x)
    x2 = absx * absx
    x3 = x2 * absx
    inner = (a + 2) * x3 - (a + 3) * x2 + 1
    outer = a * x3 - 5 * a * x2 + 8 * a * absx - 4 * a
    # The second test only matters where the first one failed, i.e. where
    # |x| > 1 already holds, so checking |x| < 2 alone selects 1 < |x| < 2.
    return jnp.where(absx <= 1, inner, jnp.where(absx < 2, outer, 0.0))
| 15 | + |
| 16 | + |
def compute_contribs(in_size,
                     out_size,
                     scale,
                     support=2.0,
                     align_corners=False,
                     dtype=None):
    """Compute the 4-tap cubic source indices and weights for one axis.

    Args:
        in_size: Input length along the axis. Only read when
            ``align_corners`` is true.
        out_size: Number of output samples along the axis.
        scale: Divisor that maps output pixel centres to input coordinates
            (``(i + 0.5) / scale - 0.5``). Only read when ``align_corners``
            is false.
        support: Not read by this function; kept so existing callers that
            pass it keep working.
        align_corners: When true, output samples are spread evenly over
            ``[0, in_size - 1]`` (a single output sample maps to 0).
        dtype: Desired dtype of the returned weights. Only floating dtypes
            are honoured; ``None`` or a non-floating dtype yields weights in
            JAX's default floating dtype.

    Returns:
        A tuple ``(idxs, weights)``. ``idxs`` has shape ``(out_size, 4)``
        and holds integer tap positions that are NOT clipped, so they may
        fall outside ``[0, in_size - 1]``. ``weights`` has shape
        ``(out_size, 4)`` and each row sums to 1.
    """
    # Coordinates are always computed in at least float32. Building them
    # directly in `dtype` truncates them for integer dtypes and loses
    # exactness for large pixel positions in bf16/fp16, which would select
    # the wrong taps. Only the final weights are cast to `dtype`.
    if dtype is not None and jnp.issubdtype(dtype, jnp.floating):
        weight_dtype = dtype
        coord_dtype = jnp.promote_types(dtype, jnp.float32)
    else:
        weight_dtype = None
        coord_dtype = None  # let JAX pick its default floating dtype

    if align_corners:
        if out_size == 1:
            in_coords = jnp.zeros((1,), dtype=coord_dtype)
        else:
            in_coords = jnp.linspace(0, in_size - 1, out_size,
                                     dtype=coord_dtype)
    else:
        out_coords = jnp.arange(out_size, dtype=coord_dtype) + 0.5
        in_coords = out_coords / scale - 0.5

    # Leftmost of the four taps sits one sample before floor(coordinate).
    left_idx = jnp.floor(in_coords).astype(jnp.int32) - 1
    idxs = left_idx[:, None] + jnp.arange(4)

    dx = in_coords[:, None] - idxs
    weights = cubic_kernel(dx)
    weights = weights / jnp.sum(weights, axis=1, keepdims=True)

    if weight_dtype is not None:
        weights = weights.astype(weight_dtype)
    return idxs, weights
35 | 42 |
|
def gather_weights(img, idxs, axis):
    """Gather slices of ``img`` along ``axis`` with out-of-range indices clamped.

    Args:
        img: Array to read from.
        idxs: Integer indices along ``axis``; values below 0 or above the
            last valid position are clamped to the nearest edge.
        axis: Axis of ``img`` to index.

    Returns:
        The result of ``jnp.take(img, clamped_idxs, axis=axis)``.
    """
    last_valid = img.shape[axis] - 1
    return jnp.take(img, jnp.clip(idxs, 0, last_valid), axis=axis)
| 47 | + |
40 | 48 |
|
def interpolate_along_axis_bchw(img, idxs, weights, axis):
    """Resample a (B, C, H, W) tensor along H (axis=2) or W (axis=3).

    Args:
        img: Input tensor laid out as (B, C, H, W).
        idxs: (out_size, k) integer tap indices along ``axis`` (k is
            typically 4 for cubic). Out-of-range indices are clamped to the
            nearest edge.
        weights: (out_size, k) weights applied to the corresponding taps.
        axis: 2 to interpolate over H, 3 to interpolate over W.

    Returns:
        ``img`` with the size of ``axis`` replaced by ``out_size``, where
        every output slice is the weighted sum of its k gathered input
        slices.

    Raises:
        ValueError: If ``axis`` is neither 2 nor 3.
    """
    # Raise instead of `assert` so the check survives `python -O`.
    if axis not in (2, 3):
        raise ValueError("Axis must be 2 (H) or 3 (W)")

    # Clip to input bounds.
    idxs = jnp.clip(idxs, 0, img.shape[axis] - 1)  # (out_size, k)

    # A single gather replaces `axis` with the pair of axes (out_size, k),
    # so no per-tap Python loop or vmap over output positions is needed.
    gathered = jnp.take(img, idxs, axis=axis)

    # Contract the tap axis `k` against the weights; the output axis `o`
    # stays where the interpolated axis used to be.
    if axis == 2:
        return jnp.einsum("ok,bcok...->bco...", weights, gathered)
    return jnp.einsum("ok,bchok...->bcho...", weights, gathered)
73 | 82 |
|
74 | 83 |
|
def interpolate_bicubic_no_aa(img, out_h, out_w, align_corners=False):
    """Bicubic resize of a (B, C, H, W) tensor without anti-aliasing.

    The resize is separable: H (axis 2) is interpolated first, then
    W (axis 3) on the intermediate result.

    Args:
        img: Input tensor; its last two dimensions are read as (H, W) and
            the interpolation is applied on axes 2 and 3.
        out_h: Output height.
        out_w: Output width.
        align_corners: Forwarded to ``compute_contribs`` to select the
            coordinate mapping.

    Returns:
        The resized tensor with H and W replaced by ``out_h`` and ``out_w``.
    """
    h, w = img.shape[-2:]

    def axis_scale(in_len, out_len):
        # Same per-axis rule for H and W: a corner-aligned ratio when there
        # is more than one output sample, otherwise plain out / in.
        if align_corners and out_len > 1:
            return (in_len - 1) / (out_len - 1)
        return out_len / in_len

    idxs_y, weights_y = compute_contribs(
        h,
        out_h,
        axis_scale(h, out_h),
        align_corners=align_corners,
        dtype=img.dtype,
    )
    tmp = interpolate_along_axis_bchw(img, idxs_y, weights_y, axis=2)

    idxs_x, weights_x = compute_contribs(
        w,
        out_w,
        axis_scale(w, out_w),
        align_corners=align_corners,
        dtype=img.dtype,
    )
    return interpolate_along_axis_bchw(tmp, idxs_x, weights_x, axis=3)
0 commit comments