-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathoptimize.py
More file actions
217 lines (187 loc) · 9.85 KB
/
Copy pathoptimize.py
File metadata and controls
217 lines (187 loc) · 9.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
"""optimize.py -- THE ONLY FILE THE AGENT MAY EDIT.
Contract (do NOT change the signature):
build_decoder(device="cuda") -> (decode_fn, info)
decode_fn(latent_cpu_f32) : Tensor[48, T, H, W] (CPU, float32) -> Tensor[3, T', H', W'] on `device`
info : dict[str, str] human-readable config description
Everything the harness measures flows through decode_fn. To make the decode faster,
change what happens INSIDE build_decoder / decode_fn. Keep the contract.
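Current state: this file started as exp_000 -- the honest NAIVE BASELINE for the Wan 2.2
VAE decoder (fp32, eager, no torch.compile, no patches), kept as the starting point so the
auto-optim loop has the full climb to demonstrate. The edit zone below now holds the best
configuration found so far:
- bfloat16 weights and activations (so decode_fn returns bfloat16);
- channels_last_3d convs;
- a bf16 nearest-exact upsample;
- native spatial padding in CausalConv3d;
- a per-module-cache decoder compiled with torch.compile (max-autotune-no-cudagraphs);
- the whole streaming decode replayed as a CUDA graph.
The gate compares against the frozen fp32 reference. The loop climbs from here.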
"""
import torch
from harness._common import CKPT_PATH
def build_decoder(device="cuda"):
    """Build the Wan 2.2 VAE decode function that the harness measures.

    Args:
        device: torch device the model is loaded on and the decode runs on.
            The decode path uses CUDA streams and CUDA graphs, so this must
            name a CUDA device.

    Returns:
        (decode_fn, info):
        - decode_fn maps a CPU float32 latent to the decoded tensor on
          ``device``. See the module docstring for the shapes. The output is
          computed in bfloat16 and clamped to [-1, 1].
        - info is a dict of human-readable config strings.
    """
    info = {}
    # --- BEGIN EDIT ZONE -----------------------------------------------------
    import types
    import torch.nn as nn
    import torch.nn.functional as F
    from einops import rearrange
    from model import get_wan_vae
    from model.vae2_2 import (CACHE_T, AttentionBlock, CausalConv3d, Decoder3d,
                              Resample, ResidualBlock, Up_ResidualBlock, Upsample,
                              unpatchify)

    dtype = torch.bfloat16
    info["dtype"] = "bfloat16"
    info["compile"] = ("torch.compile(decoder, max-autotune-no-cudagraphs, "
                       "break-free, no-coorddesc)")
    info["patches"] = ("channels_last_3d + bf16-upsample + native-spatial-pad conv "
                       "+ full-decode CUDA graph + per-module-cache break-free decoder")

    vae = get_wan_vae(CKPT_PATH, device=device, dtype=dtype)
    wan = vae.model
    decoder = wan.decoder  # eager module; replaced by its compiled wrapper below

    # channels_last_3d for every Conv3d in the decoder, and for conv2 (which
    # _decode_core applies to the latent before the decoder).
    for mod in decoder.modules():
        if isinstance(mod, nn.Conv3d):
            mod.to(memory_format=torch.channels_last_3d)
    wan.conv2.to(memory_format=torch.channels_last_3d)

    # Skip the fp32 round-trip in nearest-exact upsample (identical values in bf16).
    # The default argument binds each module now, avoiding late binding.
    for mod in decoder.modules():
        if isinstance(mod, Upsample):
            mod.forward = (lambda x, _u=mod: nn.Upsample.forward(_u, x))

    # #2 cost in the profile (~6.4s direct_copy) is CausalConv3d.forward materializing a
    # FULL padded copy of every activation via F.pad (it zeroes the conv's own padding and
    # pads spatial+temporal by hand). Let cuDNN pad the spatial dims natively (free, inside
    # the conv) and only F.pad the tiny causal-temporal-left dim. Mathematically identical
    # (gate-validated); all decoder convs are stride-1.
    def fast_causal_forward(self, x, cache_x=None):
        """Replacement CausalConv3d.forward: time padded by hand, space inside the conv.

        Args:
            x: input activations; dim 2 is time (it is what gets cat'ed/padded).
            cache_x: optional frames from the previous chunk, prepended on dim 2.
        """
        t_left = self._padding[4]  # 2*pad_t on the causal (left) side; right is 0
        if cache_x is not None and t_left > 0:
            cache_x = cache_x.to(x.device)
            x = torch.cat([cache_x, x], dim=2)
            t_left -= cache_x.shape[2]  # cached frames stand in for zero padding
        if t_left > 0:
            x = F.pad(x, (0, 0, 0, 0, t_left, 0))
        elif t_left < 0:
            # The model's F.pad with a negative pad crops; keep that behavior exactly.
            x = x[:, :, -t_left:]
        return F.conv3d(x, self.weight, self.bias, self.stride,
                        self._spatial_pad, self.dilation, self.groups)

    for mod in list(decoder.modules()) + [wan.conv2]:
        if isinstance(mod, CausalConv3d):
            # Explicit check (not assert) so it survives `python -O`; the patch was
            # only validated on stride-1 convs.
            if tuple(mod.stride) != (1, 1, 1):
                raise ValueError(
                    f"native-padding CausalConv3d patch expects stride 1, got {mod.stride}")
            # _padding = (pw, pw, ph, ph, 2*pt, 0); native conv padding = (0, ph, pw)
            mod._spatial_pad = (0, mod._padding[2], mod._padding[0])
            mod._cx = None  # per-module streaming cache (replaces the external feat_cache list)
            mod.forward = types.MethodType(fast_causal_forward, mod)

    # exp_014: break-free decoder for full inductor fusion. The remaining big eager cost is
    # the residual `add` (x + h): inductor leaves it eager because the model's decoder forwards
    # graph-break at the external feat_cache list ops + the "Rep" STRING sentinel (dynamo:
    # "Unsupported method call" / "non-Tensor"). Reimplement the decoder forwards here with a
    # per-module cache (tensor attr + int state, no list/string) so the decoder compiles
    # BREAK-FREE -> inductor fuses add->norm->silu chains across the whole block. The cache
    # math is replicated EXACTLY (gate-validated bitwise vs the streaming reference).
    def cached_conv(conv, x):
        """Run a CausalConv3d with its per-module cache, then store the new cache.

        The new cache is the last CACHE_T frames of x. If that is fewer than 2
        frames and an old cache exists, the old cache's last frame is prepended.
        The conv itself sees the *previous* cache (conv._cx).
        """
        prev = conv._cx
        cache_x = x[:, :, -CACHE_T:, :, :]
        if cache_x.shape[2] < 2 and prev is not None:
            cache_x = torch.cat([prev[:, :, -1:, :, :], cache_x], dim=2)
        else:
            cache_x = cache_x.clone()  # detach from x's storage (cat above already copies)
        out = conv(x, prev)
        conv._cx = cache_x
        return out

    def residual_forward(self, x):
        """Replacement ResidualBlock.forward: cached residual convs plus shortcut."""
        h = self.shortcut(x)  # shortcut conv is called WITHOUT cache (plain), as in the model
        for layer in self.residual:
            x = cached_conv(layer, x) if isinstance(layer, CausalConv3d) else layer(x)
        return x + h

    def resample_forward(self, x, first_chunk=False):
        """Replacement Resample.forward using per-module state instead of feat_cache.

        For mode "upsample3d", ``self._fc_state`` mirrors the model's feat_cache slot:
        0 = fresh (the model's "Rep" marker not yet stored), 1 = "Rep" stored,
        2 = a tensor cache is held in ``self._tcx``. ``first_chunk`` is accepted
        for call compatibility and is unused here.
        """
        b, c, t, h, w = x.size()
        if self.mode == "upsample3d":
            st = self._fc_state
            if st == 0:
                # First chunk: no temporal conv/upsample, only advance the state.
                self._fc_state = 1
            else:
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if st == 2 and cache_x.shape[2] < 2:
                    cache_x = torch.cat([self._tcx[:, :, -1:, :, :], cache_x], dim=2)
                if st == 1 and cache_x.shape[2] < 2:
                    cache_x = torch.cat([torch.zeros_like(cache_x), cache_x], dim=2)
                x = self.time_conv(x) if st == 1 else self.time_conv(x, self._tcx)
                self._tcx = cache_x
                self._fc_state = 2
                # Split time_conv's output into two c-channel halves and interleave
                # them along time (t -> 2t).
                x = x.reshape(b, 2, c, t, h, w)
                x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
                x = x.reshape(b, c, t * 2, h, w)
        # Spatial resample runs per frame: fold time into the batch dim and back.
        t = x.shape[2]
        x = rearrange(x, "b c t h w -> (b t) c h w")
        x = self.resample(x)
        x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
        return x

    def up_residual_forward(self, x, first_chunk=False):
        """Replacement Up_ResidualBlock.forward without the external feat_cache."""
        x_main = x.clone()
        for module in self.upsamples:
            x_main = module(x_main)  # Resample uses its own _fc_state; ResidualBlock self-caches
        if self.avg_shortcut is not None:
            return x_main + self.avg_shortcut(x, first_chunk)
        return x_main

    def decoder_forward(self, x, feat_cache=None, feat_idx=None, first_chunk=False):
        """Replacement Decoder3d.forward for one latent chunk.

        feat_cache / feat_idx are kept only for signature compatibility; all
        streaming state lives on the modules. The per-frame unpatchify is done
        here so it falls inside the compiled region.
        """
        x = cached_conv(self.conv1, x)
        for layer in self.middle:
            x = layer(x)  # ResidualBlock self-caches; AttentionBlock is cache-free
        for layer in self.upsamples:
            x = layer(x, first_chunk)
        for layer in self.head:
            x = cached_conv(layer, x) if isinstance(layer, CausalConv3d) else layer(x)
        return unpatchify(x, 2)

    resamples = []
    for mod in decoder.modules():
        if isinstance(mod, ResidualBlock):
            mod.forward = types.MethodType(residual_forward, mod)
        elif isinstance(mod, Resample):
            mod._fc_state = 0
            mod._tcx = None
            mod.forward = types.MethodType(resample_forward, mod)
            resamples.append(mod)
        elif isinstance(mod, Up_ResidualBlock):
            mod.forward = types.MethodType(up_residual_forward, mod)
    decoder.forward = types.MethodType(decoder_forward, decoder)
    cached_convs = [mod for mod in decoder.modules() if isinstance(mod, CausalConv3d)]

    wan.decoder = torch.compile(decoder, mode="max-autotune-no-cudagraphs")
    zdim = wan.z_dim
    scale = vae.scale

    def _decode_core(z_gpu_f32):
        """Decode one GPU latent, streaming one latent frame at a time.

        Reimplements wrapper.decode -> streaming_decode -> _decode with the
        per-module caches. Returns the clamped result with the batch dim removed.
        """
        # Reset all streaming state so each call starts from a fresh first chunk.
        for conv in cached_convs:
            conv._cx = None
        for res in resamples:
            res._fc_state = 0
            res._tcx = None
        z = z_gpu_f32.unsqueeze(0).to(dtype)
        z = z / scale[1].view(1, zdim, 1, 1, 1) + scale[0].view(1, zdim, 1, 1, 1)
        x = wan.conv2(z)
        outs = [wan.decoder(x[:, :, i:i + 1], first_chunk=(i == 0))
                for i in range(x.shape[2])]
        return torch.cat(outs, dim=2).clamp_(-1, 1).squeeze(0)

    # exp_008: full-decode CUDA graph. After the upsample fix the profile flipped to
    # HOST-BOUND (GPU util 65%): the 20-frame streaming loop issues thousands of tiny
    # kernels and the GPU starves on Python/launch overhead. The control flow is
    # data-independent (loop length fixed by the latent shape, first_chunk=i==0) with no
    # .item()/CPU syncs, so the whole decode is CUDA-graph capturable.
    # The graph bakes in the input shape, so one graph is kept per latent shape. A
    # previously-unseen shape would otherwise be broadcast into (or rejected by) the old
    # static input. Each graph keeps its own memory pool alive.
    graphs = {}  # tuple(latent shape) -> (CUDAGraph, static_in, static_out)

    def _capture(z):
        """Warm up on a side stream, then capture one full decode of z's shape."""
        # Warmup outside capture (required for CUDA graphs) so lazy work such as
        # torch.compile compilation / autotuning for every cache-state variant is
        # done before capture.
        side = torch.cuda.Stream(device=z.device)
        side.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side):
            for _ in range(3):
                _decode_core(z)
        torch.cuda.current_stream().wait_stream(side)
        torch.cuda.synchronize()
        static_in = z.clone()
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_out = _decode_core(static_in)
        return graph, static_in, static_out

    @torch.no_grad()
    def decode_fn(latent_cpu_f32):
        """Decode one CPU float32 latent (see the module docstring for the contract).

        The first call for a given latent shape warms up and captures the whole
        decode as a CUDA graph (capture does not run it). Every call then copies the
        latent into that graph's static input and replays it. It returns a clone of
        the static output, because the next replay overwrites that buffer.
        """
        z = latent_cpu_f32.to(device=device)  # CPU f32 -> GPU f32
        # Capture and replay on z's GPU even if it is not the current device.
        with torch.cuda.device(z.device):
            key = tuple(z.shape)
            if key not in graphs:
                graphs[key] = _capture(z)
            graph, static_in, static_out = graphs[key]
            static_in.copy_(z)
            graph.replay()
            return static_out.clone()
    # --- END EDIT ZONE -------------------------------------------------------
    return decode_fn, info