I gave Bumblebee a try today. The idea was to run image-classification predictions on an image so that a user can pick from pre-filled tags to easily filter their images.
It turns out that the predictions are not too bad, and quite fast, at least locally.
This is supposed to be a car:
https://dwyl-imgup.s3.eu-west-3.amazonaws.com/40F36F45.webp
Nx.Serving.run(serving, t_img)
#=>predictions: [
%{
label: "beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon",
score: 0.9203662276268005
}
]

Testing with a new query string, pred=on, to run the model prediction:
curl -X GET "http://localhost:4000/api?url=https://dwyl-imgup.s3.eu-west-3.amazonaws.com/40F36F45.webp&w=300&pred=on"
{"h":205,"w":300,"url":"https://dwyl-imgup.s3.eu-west-3.amazonaws.com/76F195C6.webp","new_size":11642,"predictions":[{"label":"beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon","score":0.9203662276268005}],"init_size":79294,"w_origin":960,"h_origin":656,"url_origin":"https://dwyl-imgup.s3.eu-west-3.amazonaws.com/40F36F45.webp"}

I tested 3 models: "facebook/deit-base-distilled-patch16-224", "microsoft/resnet-50", and "google/vit-base-patch16-224".
I don't know if anyone else has tested this?
I submit my code in case any reader spots an obvious fault. It runs locally. It is based on this example. I did not try to deploy this, but here is a note before I forget: you need to set up a temp dir for the model cache.
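As a minimal sketch of that temp-dir setup (the /tmp path is an assumption; Bumblebee reads the BUMBLEBEE_CACHE_DIR environment variable, and its load functions also accept a :cache_dir option):

```elixir
# config/runtime.exs — point Bumblebee's model download cache at a
# writable directory on the deploy target. The exact path is up to you.
System.put_env("BUMBLEBEE_CACHE_DIR", "/tmp/bumblebee_cache")
```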
#mix.exs
{:bumblebee, "~> 0.4.2"},
{:nx, "~> 0.6.1"},
{:exla, "~> 0.6.1"},
{:axon, "~> 0.6.0"},

I decided to run a GenServer started with the app to load the model and hold the serving, but you can start an Nx.Serving at the Application level as well, something like {Nx.Serving, serving: serve(), name: UpImg.Serving}, where the function serve/0 builds the same serving as in the GenServer below.
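For reference, a hedged sketch of that Application-level alternative (the module and child names are assumptions; the serving is built exactly as in the GenServer below, and callers would then use Nx.Serving.batched_run/2 instead of Nx.Serving.run/2):

```elixir
# lib/up_img/application.ex — supervised Nx.Serving variant.
def serve do
  model = System.fetch_env!("MODEL")
  {:ok, resnet} = Bumblebee.load_model({:hf, model})
  {:ok, featurizer} = Bumblebee.load_featurizer({:hf, model})

  Bumblebee.Vision.image_classification(resnet, featurizer,
    defn_options: [compiler: EXLA],
    top_k: 1,
    compile: [batch_size: 10]
  )
end

# in start/2, the serving is built once at boot and supervised:
children = [
  {Nx.Serving, serving: serve(), name: UpImg.Serving, batch_size: 10}
]
```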
defmodule UpImg.GsPredict do
  use GenServer

  def start_link(opts) do
    {:ok, model} = Keyword.fetch(opts, :model)
    GenServer.start_link(__MODULE__, model, name: __MODULE__)
  end

  def serve, do: GenServer.call(__MODULE__, :serving)

  @impl true
  def init(model) do
    {:ok, model, {:continue, :load_model}}
  end

  @impl true
  def handle_continue(:load_model, model) do
    {:ok, resnet} = Bumblebee.load_model({:hf, model})
    {:ok, featurizer} = Bumblebee.load_featurizer({:hf, model})

    {:noreply,
     Bumblebee.Vision.image_classification(resnet, featurizer,
       defn_options: [compiler: EXLA],
       top_k: 1,
       compile: [batch_size: 10]
     )}
  end

  @impl true
  def handle_call(:serving, _from, serving) do
    {:reply, serving, serving}
  end
end

and it is started with the app:
children = [
  ...,
  {UpImg.GsPredict, [model: System.fetch_env!("MODEL")]}
]

The model (the Hugging Face repo id) is passed as an env var so I can very simply change it.
In the API, I use predict/1 when I upload an image from the browser and run this task in parallel to the S3 upload. It takes a Vix.Vips.Image, obtained by transforming the binary file:
[EDITED]
def predict(%Vix.Vips.Image{} = image) do
  serving = UpImg.GsPredict.serve()

  {:ok, %Vix.Tensor{data: data, shape: shape, names: names, type: type}} =
    Vix.Vips.Image.write_to_tensor(image)

  # Note: shape is HWC ({height, width, channels}), not {width, height, channels}.
  # Bug corrected.
  t_img = Nx.from_binary(data, type) |> Nx.reshape(shape, names: names)

  Task.async(fn -> Nx.Serving.run(serving, t_img) end)
end

and use it in the flow:
prediction_task = predict(my_image)
...
%{predictions: predictions} = Task.await(prediction_task)
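As a usage sketch of that await (the explicit 10_000 ms timeout is an assumption, not from the original code; Task.await/2 defaults to 5_000 ms, which the first run can exceed while EXLA compiles the model):

```elixir
# Await the prediction that ran alongside the S3 upload and
# destructure the single top_k: 1 result.
%{predictions: [%{label: label, score: score}]} =
  Task.await(prediction_task, 10_000)
```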