
Support offloading encode, for generate() with much less VRAM #269

Open

drdaxxy wants to merge 1 commit into borisdayma:main from drdaxxy:generate-accept-sc-encode

Conversation

@drdaxxy

@drdaxxy drdaxxy commented Jun 19, 2022

generate() from Transformers can take encoder outputs as kwargs instead of running the encoder. This PR extends this to "super conditioning" sampling. It also enables providing only one "null sequence" per batch, as inputs or encoder state, since that prompt is normally constant.

How is this useful? We only need to run the encoder once per distinct prompt, which even on a household CPU takes 1-2 seconds for a single input (worst case: no batching, no reuse). By offloading this step, generate() works without 2 or 4 gigabytes of encoder weights (mega-1-fp16 and mega-1, respectively) hogging VRAM.

That way, mega-1-fp16 can run on a 4GB GPU (1-batches, without VQGAN, which is fast enough on CPU) and full-precision mega-1 can run on an 8GB GPU (1-batches with VQGAN, up to 3-batches without).

Specifically, without VQGAN, 1-batches need 3728 MiB in float16, 6770 MiB in float32 this way. GPU-accelerating VQGAN adds 770 MiB, assuming we also del vqgan_params["encoder"] (we never need these for generating images) before replicate(vqgan_params) or the like.
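The VQGAN pruning step mentioned above might look like the following sketch. The parameter tree here is a hypothetical stand-in; the real one is loaded from the VQGAN checkpoint, and replicate would be flax.jax_utils.replicate.

```python
# Hypothetical stand-in for the VQGAN parameter tree; the real one is
# loaded from a checkpoint. Only the decoder and quantizer are needed
# to turn generated image codes back into pixels.
vqgan_params = {
    "encoder": {"conv_in": [0.0] * 4},     # images -> codes (unused when generating)
    "decoder": {"conv_out": [0.0] * 4},    # codes -> images (needed)
    "quantize": {"embedding": [0.0] * 4},  # codebook lookup (needed)
}

# Drop the encoder weights before replicating the params across devices,
# e.g. before flax.jax_utils.replicate(vqgan_params) or the like.
del vqgan_params["encoder"]
```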

On systems that have enough memory anyway, up to 10 (fp32) or 20 (fp16) more items fit in a batch. Given the CPU encode cost, in my experience this ends up a few percent slower or faster depending on how much state is shared (especially when combined with the other tricks in #247).
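The overall flow can be sketched with toy stand-in functions (all names below are hypothetical, not the real dalle-mini or Transformers API): the encoder runs once per distinct prompt, and every subsequent decode batch reuses the cached encoder output instead of keeping encoder weights on the GPU.

```python
# Toy sketch of the offloaded-encode pattern. encode() stands in for the
# CPU-side text encoder; generate() stands in for GPU-side decoder
# sampling, which only needs the encoder's *outputs*, not its weights.

def encode(prompt):
    # stand-in for the real text encoder (runs on CPU, the slow step)
    return [ord(c) % 7 for c in prompt]

def generate(encoder_outputs, n_images):
    # stand-in for decoder sampling; never touches encoder weights
    return [sum(encoder_outputs) + i for i in range(n_images)]

cache = {}

def encode_once(prompt):
    # one encode per distinct prompt, however many batches follow
    if prompt not in cache:
        cache[prompt] = encode(prompt)
    return cache[prompt]

enc = encode_once("a red cube")
images = [generate(enc, 3) for _ in range(4)]  # 4 batches, but only 1 encode
```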

@Kepler-Br

Sounds awesome!
But as someone who just wants to try it out, I can't figure out quickly how to use offloading.
Could you please add a usage example?

@drdaxxy
Author

drdaxxy commented Jun 19, 2022

Could you please add a usage example?

I don't have time to write a proper example now, sorry... I'm hoping another developer decides to take care of that.

@TakuSmash

Could this also get the full model working on a GPU with much less VRAM? The full mega checkpoint, instead of just the fp16 one?

@Kepler-Br

Kepler-Br commented Jun 21, 2022

Could this also get the full model working on a GPU with much less VRAM? The full mega checkpoint, instead of just the fp16 one?

I guess so:

full-precision mega-1 can run on an 8GB GPU (1-batches with VQGAN, up to 3-batches without).

@TakuSmash

TakuSmash commented Jun 21, 2022

Could this also get the full model working on a GPU with much less VRAM? The full mega checkpoint, instead of just the fp16 one?

I guess so:

full-precision mega-1 can run on an 8GB GPU (1-batches with VQGAN, up to 3-batches without).

Wait, so my RTX 3060 should already be good to go for running this in something like Visions of Chaos? The full checkpoint?

@borisdayma
Owner

Those are very interesting ideas @drdaxxy !

I'm gonna try to think about how to integrate it in a clean way.
