Skip to content

Commit d16752e

Browse files
authored
Add TPU example for Ray Serve with Stable Diffusion
2 parents 5a10842 + f6f3577 commit d16752e

File tree

3 files changed

+276
-0
lines changed

3 files changed

+276
-0
lines changed
3.94 MB
Loading
+144
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
"""Ray Serve Stable Diffusion example."""
2+
from io import BytesIO
3+
from typing import List
4+
from fastapi import FastAPI
5+
from fastapi.responses import Response
6+
import logging
7+
import ray
8+
from ray import serve
9+
import time
10+
11+
# FastAPI application; its routes are exposed through the APIIngress
# deployment, which is decorated with serve.ingress(app).
app = FastAPI()
12+
# Batch size used throughout this module: the max_batch_size of the
# @serve.batch handler, the length every batch is padded to, and the default
# warmup batch size.
_MAX_BATCH_SIZE = 64
13+
14+
# Log through the logger named "ray.serve".
logger = logging.getLogger("ray.serve")
15+
16+
@serve.deployment(num_replicas=1)
@serve.ingress(app)
class APIIngress:
    """HTTP entry point that forwards prompts to the diffusion deployment.

    The FastAPI ``app`` is mounted through ``serve.ingress``. Its only route,
    ``GET /imagine``, passes the prompt to the handle received at
    construction time and returns whatever that call yields.
    """

    def __init__(self, diffusion_model_handle) -> None:
        """Stores the handle used to reach the diffusion deployment.

        Args:
          diffusion_model_handle: Handle whose ``generate.remote(prompt)`` is
            awaited for every request.
        """
        self.handle = diffusion_model_handle

    @app.get(
        "/imagine",
        responses={200: {"content": {"image/png": {}}}},
        response_class=Response,
    )
    async def generate(self, prompt: str):
        """Generates an image for ``prompt``.

        Args:
          prompt: Text prompt taken from the ``prompt`` query parameter.

        Returns:
          The result of the diffusion deployment's ``generate`` call.

        Raises:
          HTTPException: With status 400 when ``prompt`` is empty.
        """
        # Local import, matching the function-scope imports used elsewhere
        # in this module; fastapi is already a dependency of the file.
        from fastapi import HTTPException

        # Validate explicitly instead of with `assert`: asserts are stripped
        # under `python -O`, and an AssertionError would be reported to the
        # client as a 500 rather than as a client error.
        if not prompt:
            raise HTTPException(
                status_code=400, detail="prompt parameter cannot be empty")

        return await self.handle.generate.remote(prompt)
32+
33+
34+
@serve.deployment(
    ray_actor_options={
        "resources": {"TPU": 4},
    },
)
class StableDiffusion:
    """FLAX Stable Diffusion Ray Serve deployment running on TPUs.

    The deployment reserves 4 "TPU" resources, loads the pipeline once in the
    constructor and serves prompts in fixed-size batches of
    ``_MAX_BATCH_SIZE``.
    """

    def __init__(
        self,
        run_with_profiler: bool = False,
        warmup: bool = False,
        warmup_batch_size: int = _MAX_BATCH_SIZE,
    ):
        """Loads the pipeline and optionally runs one warmup batch.

        Args:
          run_with_profiler: Whether or not to run with the profiler. Note
            that this saves the profile to the separate TPU VM (under
            ``/tmp/tensorboard``).
          warmup: Whether to run one batch of dummy prompts right after
            loading the model.
          warmup_batch_size: Number of dummy prompts in the warmup batch.
        """
        # Heavy ML dependencies are imported lazily, inside the method.
        from diffusers import FlaxStableDiffusionPipeline
        from flax.jax_utils import replicate
        import jax.numpy as jnp
        from jax import pmap

        model_id = "CompVis/stable-diffusion-v1-4"

        # Load the "bf16" revision of the weights with a bfloat16 dtype.
        self._pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
            model_id,
            revision="bf16",
            dtype=jnp.bfloat16)

        self._p_params = replicate(params)
        self._p_generate = pmap(self._pipeline._generate)
        self._run_with_profiler = run_with_profiler
        self._profiler_dir = "/tmp/tensorboard"

        if warmup:
            logger.info("Sending warmup requests.")
            warmup_prompts = ["A warmup request"] * warmup_batch_size
            self.generate_tpu(warmup_prompts)

    def generate_tpu(self, prompts: List[str]):
        """Generates a batch of images from Diffusion from a list of prompts.

        Args:
          prompts: a non-empty list of strings. The prompt ids are sharded
            before generation, so the length is expected to be a multiple of
            the number of devices (4 "TPU" resources are reserved for this
            deployment) -- NOTE(review): confirm against `shard`.

        Returns:
          A list of PIL Images.

        Raises:
          ValueError: If ``prompts`` is empty.
        """
        from flax.training.common_utils import shard
        import jax
        import numpy as np

        # Explicit validation rather than `assert`, which is stripped when
        # Python runs with -O. Checked before any device work is started.
        if not prompts:
            raise ValueError("prompt parameter cannot be empty")

        # The same seed (0) is used on every call; it is split into one key
        # per device.
        rng = jax.random.PRNGKey(0)
        rng = jax.random.split(rng, jax.device_count())

        logger.info("Prompts: %s", prompts)
        prompt_ids = shard(self._pipeline.prepare_inputs(prompts))
        logger.info("Sharded prompt ids has shape: %s", prompt_ids.shape)

        if self._run_with_profiler:
            jax.profiler.start_trace(self._profiler_dir)
        try:
            time_start = time.perf_counter()
            images = self._p_generate(prompt_ids, self._p_params, rng)
            # Wait for the computation to finish so the timing is meaningful.
            images = images.block_until_ready()
            elapsed = time.perf_counter() - time_start
        finally:
            # Always close the trace, even when generation raises; otherwise
            # the profiler would be left running for the next request.
            if self._run_with_profiler:
                jax.profiler.stop_trace()

        logger.info("Inference time (in seconds): %f", elapsed)
        logger.info("Shape of the predictions: %s", images.shape)
        # Merge the two leading axes into a single batch axis and keep the
        # last three axes as they are.
        images = images.reshape(
            (images.shape[0] * images.shape[1],) + images.shape[-3:])
        logger.info("Shape of images afterwards: %s", images.shape)
        return self._pipeline.numpy_to_pil(np.array(images))

    @serve.batch(batch_wait_timeout_s=10, max_batch_size=_MAX_BATCH_SIZE)
    async def batched_generate_handler(self, prompts: List[str]):
        """Sends a batch of prompts to the TPU model server.

        This takes advantage of @serve.batch, Ray Serve's built-in batching
        mechanism. The batch is padded with dummy prompts up to
        ``_MAX_BATCH_SIZE`` before generation, and the images produced for
        the padding are discarded.

        Args:
          prompts: A list of input prompts

        Returns:
          A list of responses whose contents are raw PNG, one per input
          prompt and in the same order.
        """
        num_prompts = len(prompts)
        logger.info("Number of input prompts: %d", num_prompts)

        # Build a new padded list instead of extending `prompts` in place, so
        # the list handed in by the caller is left untouched.
        num_to_pad = _MAX_BATCH_SIZE - num_prompts
        padded_prompts = list(prompts) + ["Scratch request"] * num_to_pad

        images = self.generate_tpu(padded_prompts)
        results = []
        # Only the first `num_prompts` images belong to real requests.
        for image in images[:num_prompts]:
            file_stream = BytesIO()
            image.save(file_stream, "PNG")
            results.append(
                Response(content=file_stream.getvalue(), media_type="image/png")
            )
        return results

    async def generate(self, prompt):
        """Generates the image response for a single prompt.

        Delegates to ``batched_generate_handler``, which is wrapped with
        ``@serve.batch``.

        Args:
          prompt: A single input prompt.

        Returns:
          The value awaited from ``batched_generate_handler`` for ``prompt``.
        """
        return await self.batched_generate_handler(prompt)
141+
142+
143+
# Bind the StableDiffusion deployment with its default constructor arguments.
diffusion_bound = StableDiffusion.bind()
144+
# Bind the HTTP ingress, passing the bound diffusion deployment as its
# constructor argument.
deployment = APIIngress.bind(diffusion_bound)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import argparse
2+
from concurrent import futures
3+
import functools
4+
from io import BytesIO
5+
import numpy as np
6+
from PIL import Image
7+
import requests
8+
from tqdm import tqdm
9+
10+
11+
_PROMPTS = [
12+
"Labrador in the style of Hokusai",
13+
"Painting of a squirrel skating in New York",
14+
"HAL-9000 in the style of Van Gogh",
15+
"Times Square under water, with fish and a dolphin swimming around",
16+
"Ancient Roman fresco showing a man working on his laptop",
17+
"Armchair in the shape of an avocado",
18+
"Clown astronaut in space, with Earth in the background",
19+
"A cat sitting on a windowsill",
20+
"A dog playing fetch in a park",
21+
"A city skyline at night",
22+
"A field of flowers in bloom",
23+
"A tropical beach with palm trees",
24+
"A snowy mountain range",
25+
"A waterfall cascading into a pool",
26+
"A forest at sunset",
27+
"A desert landscape with cacti",
28+
"A volcano erupting",
29+
"A lightning storm in the distance",
30+
"A rainbow over a rainbow",
31+
"A unicorn grazing in a meadow",
32+
"A dragon flying through the sky",
33+
"A mermaid swimming in the ocean",
34+
"A robot walking down the street",
35+
"A UFO landing in a field",
36+
"A portal to another dimension",
37+
"A time traveler from the future",
38+
"A talking cat",
39+
"A bowl of fruit on a table",
40+
"A group of friends laughing",
41+
"A family sitting down for dinner",
42+
"A couple kissing in the rain",
43+
"A child playing with a toy",
44+
"A musician playing an instrument",
45+
"A painter painting a picture",
46+
"A writer writing a book",
47+
"A scientist conducting an experiment",
48+
"A construction worker building a house",
49+
"A doctor operating on a patient",
50+
"A teacher teaching a class",
51+
"A police officer arresting a suspect",
52+
"A firefighter putting out a fire",
53+
"A soldier fighting in a war",
54+
"A farmer working in a field",
55+
"A pilot flying a plane",
56+
"An astronaut in space",
57+
"A unicorn eating a rainbow"
58+
]
59+
60+
61+
def send_request_and_receive_image(
        prompt: str, url: str, timeout=None) -> "BytesIO | None":
    """Sends a single prompt request and returns the raw image bytes.

    Args:
      prompt: Text prompt, sent as the ``prompt`` query parameter.
      url: Endpoint the GET request is sent to.
      timeout: Optional timeout in seconds forwarded to ``requests.get``.
        The default ``None`` waits indefinitely, which is the previous
        behavior.

    Returns:
      A BytesIO wrapping the response body, or None when the request failed
      (the error is printed, not raised).
    """
    try:
        # Let requests build the query string so every reserved character in
        # the prompt ("&", "#", "+", ...) is escaped, not only spaces.
        resp = requests.get(
            url, params={"prompt": str(prompt)}, timeout=timeout)
        resp.raise_for_status()
    except requests.RequestException as e:
        # Best effort: report the failure and let the caller carry on with
        # the remaining requests.
        print(f"An error occurred while sending the request: {e}")
        return None
    return BytesIO(resp.content)
70+
71+
72+
def image_grid(imgs, rows, cols):
    """Pastes ``imgs`` onto one RGB canvas of ``rows`` x ``cols`` cells.

    Each cell has the size of the first image. Images are placed left to
    right, top to bottom, in the order given.

    Args:
      imgs: Sequence of images; the first one determines the cell size.
      rows: Number of rows of the canvas.
      cols: Number of columns of the canvas.

    Returns:
      The new image containing the pasted images.
    """
    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for index, picture in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(picture, box=(col * cell_w, row * cell_h))
    return canvas
78+
79+
80+
def send_requests(num_requests: int, batch_size: int, save_pictures: bool,
                  url: str = "http://localhost:8000/imagine"):
    """Sends a list of requests and processes the responses.

    Args:
      num_requests: Number of prompts to send. They are drawn at random,
        without replacement, from ``_PROMPTS`` (repeated first when more
        requests than prompts are asked for).
      batch_size: Number of worker threads, i.e. how many requests are in
        flight at the same time.
      save_pictures: Whether to save the received images as one grid image
        in ``./diffusion_results.png``.
      url: Endpoint the requests are sent to.
    """
    print("num_requests: ", num_requests)
    print("batch_size: ", batch_size)
    print("url: ", url)
    print("save_pictures: ", save_pictures)

    prompts = _PROMPTS
    if num_requests > len(_PROMPTS):
        # Repeat the pool until it holds at least num_requests prompts, so
        # sampling without replacement below is possible.
        prompts = _PROMPTS * int(np.ceil(num_requests / len(_PROMPTS)))

    prompts = np.random.choice(
        prompts, num_requests, replace=False)

    with futures.ThreadPoolExecutor(max_workers=batch_size) as executor:
        raw_images = list(
            tqdm(
                executor.map(
                    functools.partial(send_request_and_receive_image, url=url),
                    prompts,
                ),
                total=len(prompts),
            )
        )

    if not save_pictures:
        return

    # A failed request yields None instead of image bytes; skip those so a
    # single failure does not prevent saving the remaining pictures.
    received = [raw_image for raw_image in raw_images if raw_image is not None]
    if not received:
        print("No images were received; nothing to save.")
        return
    if len(received) < len(raw_images):
        print(f"Only {len(received)} of {len(raw_images)} requests succeeded.")

    print("Saving pictures to diffusion_results.png")
    images = [Image.open(raw_image) for raw_image in received]
    # Two rows as before, but round the column count up so that an odd number
    # of images still fits, and use a single row for a single image.
    rows = 2 if len(images) > 1 else 1
    cols = (len(images) + rows - 1) // rows
    grid = image_grid(images, rows, cols)
    grid.save("./diffusion_results.png")
112+
113+
114+
if __name__ == "__main__":
    # Command-line entry point: parse the options and fire the requests.
    parser = argparse.ArgumentParser(description="Sends requests to Diffusion.")
    parser.add_argument(
        "--num_requests", help="Number of requests to send.",
        type=int, default=8)
    parser.add_argument(
        "--batch_size", help="The number of requests to send at a time.",
        type=int, default=8)
    parser.add_argument(
        "--save_pictures", default=False, action="store_true",
        help="Whether to save the generated pictures to disk.")
    parser.add_argument(
        "--ip", help="The IP address to send the requests to.")

    args = parser.parse_args()

    # Honor --ip, which was previously parsed but never used. Without it,
    # send_requests falls back to its own default URL.
    request_kwargs = {}
    if args.ip:
        # NOTE(review): port 8000 and the /imagine route mirror the default
        # URL of send_requests -- confirm they match the deployed service.
        request_kwargs["url"] = f"http://{args.ip}:8000/imagine"

    send_requests(
        num_requests=args.num_requests, batch_size=args.batch_size,
        save_pictures=args.save_pictures, **request_kwargs)

0 commit comments

Comments
 (0)