Open
Description
I'm using the Stable Diffusion Configuration and modified it in order to run on an inferentia2 instance.
My inference script is as follows:
from io import BytesIO
from fastapi import FastAPI
from fastapi.responses import Response
import torch
import torch_neuronx
import os
import base64
from ray import serve
from optimum.neuron import NeuronStableDiffusionXLPipeline
app = FastAPI()
os.environ["NEURON_RT_NUM_CORES"] = "2"
neuron_cores = 2
@serve.deployment(num_replicas=1, route_prefix="/")
@serve.ingress(app)
class APIIngress:
def __init__(self, diffusion_model_handle) -> None:
self.handle = diffusion_model_handle
@app.get(
"/imagine",
responses={200: {"content": {"image/png": {}}}},
response_class=Response,
)
async def generate(self, prompt: str):
assert len(prompt), "prompt parameter cannot be empty"
image_ref = await self.handle.generate.remote(prompt)
image = await image_ref
file_stream = BytesIO()
image.save(file_stream, "PNG")
return Response(content=file_stream.getvalue(), media_type="image/png")
@serve.deployment(
ray_actor_options={
"resources": {"neuron_cores": neuron_cores},
"runtime_env": {"env_vars": {"NEURON_CC_FLAGS": "-O1"}},
},
autoscaling_config={"min_replicas": 1, "max_replicas": 2},
)
class StableDiffusionV2:
def __init__(self):
from optimum.neuron import NeuronStableDiffusionXLPipeline
model_dir = "sdxl_neuron/"
self.pipe = NeuronStableDiffusionXLPipeline.from_pretrained(model_dir, device_ids=[0, 1])
def generate(self, prompt: str):
assert len(prompt), "prompt parameter cannot be empty"
image = self.pipe(prompt).images[0]
return image
entrypoint = APIIngress.bind(StableDiffusionV2.bind())
When I send a request to the endpoint like http://127.0.0.1:8000/imagine?prompt={input}
I get the error below:
(ServeReplica:default:APIIngress pid=9170) response = await func(request)
(ServeReplica:default:APIIngress pid=9170) File "/home/ec2-user/aws_neuron_venv_pytorch/lib64/python3.8/site-packages/fastapi/routing.py", line 299, in app
(ServeReplica:default:APIIngress pid=9170) raise e
(ServeReplica:default:APIIngress pid=9170) File "/home/ec2-user/aws_neuron_venv_pytorch/lib64/python3.8/site-packages/fastapi/routing.py", line 294, in app
(ServeReplica:default:APIIngress pid=9170) raw_response = await run_endpoint_function(
(ServeReplica:default:APIIngress pid=9170) File "/home/ec2-user/aws_neuron_venv_pytorch/lib64/python3.8/site-packages/fastapi/routing.py", line 191, in run_endpoint_function
(ServeReplica:default:APIIngress pid=9170) return await dependant.call(**values)
(ServeReplica:default:APIIngress pid=9170) File "/home/ec2-user/./inference_sd.py", line 38, in generate
(ServeReplica:default:APIIngress pid=9170) image = await image_ref
(ServeReplica:default:APIIngress pid=9170) TypeError: object Image can't be used in 'await' expression
(ServeReplica:default:APIIngress pid=9170) INFO 2024-01-16 16:46:30,563 default_APIIngress LcGOVN 7f82dbfd-bc29-40e6-9326-16029bd06b22 /imagine replica.py:772 - __CALL__ ERROR 14203.6ms
Metadata
Metadata
Assignees
Labels
No labels