Skip to content

Commit b8b6241

Browse files
authored
Merge pull request #47 from ndrean/semantic
Semantic search added
2 parents 9641719 + 7b5287d commit b8b6241

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+4262
-766
lines changed

.gitignore

+8-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ erl_crash.dump
1919
# Also ignore archive artifacts (built via "mix archive.build").
2020
*.ez
2121

22+
# Ignore DB dumps
23+
*.db
24+
2225
# Temporary files, for example, from tests.
2326
/tmp/
2427

@@ -37,4 +40,8 @@ npm-debug.log
3740

3841
# Bumblebee model directory
3942
.bumblebee/*
40-
.elixir_ls
43+
.elixir_ls
44+
45+
# KNN index directory
46+
priv/static/uploads/indexes.bin
47+

README.md

+2,459-286
Large diffs are not rendered by default.

_comparison/manage_models.exs

+7-4
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,16 @@ defmodule Comparison.Models do
1212
def verify_and_download_model(model, force_download? \\ false) do
1313
case force_download? do
1414
true ->
15-
File.rm_rf!(model.cache_path) # Delete any cached pre-existing model
16-
download_model(model) # Download model
15+
# Delete any cached pre-existing model
16+
File.rm_rf!(model.cache_path)
17+
# Download model
18+
download_model(model)
1719

1820
false ->
1921
# Check if the model cache directory exists or if it's not empty.
2022
# If so, we download the model.
2123
model_location = Path.join(model.cache_path, "huggingface")
24+
2225
if not File.exists?(model_location) or File.ls!(model_location) == [] do
2326
download_model(model)
2427
end
@@ -50,7 +53,7 @@ defmodule Comparison.Models do
5053
# It will load the model and the respective featurizer, tokenizer and generation config if needed,
5154
# and return a map with all of these at the end.
5255
defp load_offline_model_params(model) do
53-
Logger.info("Loading #{model.name}...")
56+
Logger.info("ℹ️ Loading #{model.name}...")
5457

5558
# Loading model
5659
loading_settings = {:hf, model.name, cache_dir: model.cache_path, offline: true}
@@ -92,7 +95,7 @@ defmodule Comparison.Models do
9295
# Downloads the models according to a given %ModelInfo struct.
9396
# It will load the model and the respective featurizer, tokenizer and generation config if needed.
9497
defp download_model(model) do
95-
Logger.info("Downloading #{model.name}...")
98+
Logger.info("ℹ️ Downloading #{model.name}...")
9699

97100
# Download model
98101
downloading_settings = {:hf, model.name, cache_dir: model.cache_path}

_comparison/run.exs

+2-2
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ defmodule Benchmark do
8080
coco_dataset_images_path = File.cwd!() |> Path.join("coco_dataset") |> Path.join("*.jpg")
8181
files = Path.wildcard(coco_dataset_images_path)
8282

83-
#coco_dataset_captions =
83+
# coco_dataset_captions =
8484
# File.stream!(File.cwd!() |> Path.join("coco_dataset") |> Path.join("captions.csv"))
8585
# |> CSV.decode!()
8686
# |> Enum.map(& &1)
@@ -120,7 +120,7 @@ defmodule Benchmark do
120120

121121
# Go over each image and make prediction
122122
Enum.each(vips_images_with_captions, fn image ->
123-
Logger.info("Benchmarking image #{image.id}...")
123+
Logger.info("📊 Benchmarking image #{image.id}...")
124124

125125
# Run the prediction
126126
{time_in_microseconds, prediction} =

assets/js/micro.js

+4-6
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,21 @@ export default {
1111
blue = ["bg-blue-500", "hover:bg-blue-700"],
1212
pulseGreen = ["bg-green-500", "hover:bg-green-700", "animate-pulse"];
1313

14-
1514
_this = this;
1615

1716
// Adding event listener for "click" event
1817
recordButton.addEventListener("click", () => {
19-
2018
// Check if it's recording.
2119
// If it is, we stop the record and update the elements.
2220
if (mediaRecorder && mediaRecorder.state === "recording") {
2321
mediaRecorder.stop();
22+
// audioChunks.getAudioTracks()[0].stop();
2423
text.textContent = "Record";
25-
}
24+
}
2625

2726
// Otherwise, it means the user wants to start recording.
2827
else {
2928
navigator.mediaDevices.getUserMedia({ audio: true }).then((stream) => {
30-
3129
// Instantiate MediaRecorder
3230
mediaRecorder = new MediaRecorder(stream);
3331
mediaRecorder.start();
@@ -39,7 +37,7 @@ export default {
3937

4038
// Add "dataavailable" event handler
4139
mediaRecorder.addEventListener("dataavailable", (event) => {
42-
audioChunks.push(event.data);
40+
event.data.size > 0 && audioChunks.push(event.data);
4341
});
4442

4543
// Add "stop" event handler for when the recording stops.
@@ -57,4 +55,4 @@ export default {
5755
}
5856
});
5957
},
60-
};
58+
};

config/config.exs

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ config :app,
1313
generators: [timestamp_type: :utc_datetime]
1414

1515
# Tells `NX` to use `EXLA` as backend
16-
# config :nx, default_backend: EXLA.Backend
16+
# config :nx, default_backend: EXLA.Backend
1717
# needed to run on `Fly.io`
1818
config :nx, :default_backend, {EXLA.Backend, client: :host}
1919

config/dev.exs

-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ config :app, App.Repo,
1010
show_sensitive_data_on_connection_error: true,
1111
pool_size: 10
1212

13-
1413
# For development, we disable any cache and enable
1514
# debugging and code reloading.
1615
#

config/test.exs

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ config :logger, level: :warning
2626
# Initialize plugs at runtime for faster test compilation
2727
config :phoenix, :plug_init_mode, :runtime
2828

29-
3029
# App configuration
3130
config :app,
31+
start_genserver: false,
32+
knnindex_indices_test: true,
3233
use_test_models: true

deployment.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -926,7 +926,7 @@ defmodule App.Models do
926926
# It will load the model and the respective featurizer, tokenizer and generation config if needed,
927927
# and return a map with all of these at the end.
928928
defp load_offline_model(model) do
929-
Logger.info("Loading #{model.name}...")
929+
Logger.info("ℹ️ Loading #{model.name}...")
930930

931931
# Loading model
932932
loading_settings = {:hf, model.name, cache_dir: model.cache_path, offline: true}
@@ -968,7 +968,7 @@ defmodule App.Models do
968968
# Downloads the models according to a given %ModelInfo struct.
969969
# It will load the model and the respective featurizer, tokenizer and generation config if needed.
970970
defp download_model(model) do
971-
Logger.info("Downloading #{model.name}...")
971+
Logger.info("ℹ️ Downloading #{model.name}...")
972972

973973
# Download model
974974
downloading_settings = {:hf, model.name, cache_dir: model.cache_path}

example.txt

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
Mix.install([
2+
{:bumblebee, "~> 0.4.2"},
3+
{:exla, "~> 0.6.4"},
4+
{:nx, "~> 0.6.4 "},
5+
{:hnswlib, "~> 0.1.4"}
6+
])
7+
8+
Nx.global_default_backend(EXLA.Backend)
9+
10+
{:ok, index} = HNSWLib.Index.new(_space = :cosine, _dim = 384, _max_elements = 200)
11+
transformer = "sentence-transformers/paraphrase-MiniLM-L6-v2"
12+
{:ok, %{model: _model, params: _params} = model_info} =
13+
Bumblebee.load_model({:hf, transformer})
14+
15+
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, transformer})
16+
serving = Bumblebee.Text.TextEmbedding.text_embedding(
17+
model_info,
18+
tokenizer,
19+
defn_options: [compiler: EXLA, lazy_transfers: :never]
20+
#output_pool: :mean_pooling,
21+
#output_attribute: :hidden_state,
22+
#embedding_processor: :l2_norm,
23+
)
24+
25+
%{embedding: data} = Nx.Serving.run(serving, "small") |>dbg()
26+
HNSWLib.Index.add_items(index, data)
27+
HNSWLib.Index.get_count(index) |> dbg()
28+
29+
%{embedding: data} = Nx.Serving.run(serving, "tall") |> dbg()
30+
HNSWLib.Index.add_items(index, data)
31+
HNSWLIb.Index.get_count(index) |> dbg()
32+
33+
%{embedding: data} = Nx.Serving.run(serving, "high")
34+
{:ok, labels, distances} = HNSWLib.Index.knn_query(index, data, k: 1) |> dbg()
35+
idx = Nx.to_flat_list(labels[0])
36+
{:ok, dt} = HNSWLib.Index.get_items(index, idx)
37+
Nx.stack(Enum.map(dt, fn d -> Nx.from_binary(d, :f32) end))
38+
39+
defmodule Embedding do
40+
use GenServer
41+
@indexes "indexes.bin"
42+
43+
def start_link(norm) do
44+
GenServer.start_link(__MODULE__, norm, name: __MODULE__)
45+
end
46+
47+
# upload or create a new index file
48+
def init(norm) do
49+
space = norm
50+
51+
{:ok, index} =
52+
case File.exists?(@indexes) do
53+
false ->
54+
HNSWLib.Index.new(_space = space, _dim = 384, _max_elements = 200)
55+
56+
true ->
57+
HNSWLib.Index.load_index(space, 384, @indexes)
58+
end
59+
60+
model_info = nil
61+
tokenizer = nil
62+
{:ok, {model_info, tokenizer, index}, {:continue, :load}}
63+
end
64+
65+
def handle_continue(:load, {_, _, index}) do
66+
transformer = "sentence-transformers/paraphrase-MiniLM-L6-v2"
67+
68+
{:ok, %{model: _model, params: _params} = model_info} =
69+
Bumblebee.load_model({:hf, transformer})
70+
71+
{:ok, tokenizer} =
72+
Bumblebee.load_tokenizer({:hf, transformer})
73+
74+
{:noreply, {model_info, tokenizer, index}}
75+
end
76+
77+
def serve() do
78+
GenServer.call(__MODULE__, :serve)
79+
end
80+
81+
def get_count do
82+
GenServer.call(__MODULE__, :get_count)
83+
end
84+
85+
def get_index do
86+
GenServer.call(__MODULE__, :get_index)
87+
end
88+
89+
def handle_call(:serve, _from, {model_info, tokenizer, index} = state) do
90+
serving = Bumblebee.Text.TextEmbedding.text_embedding(
91+
model_info,
92+
tokenizer,
93+
output_pool: :mean_pooling,
94+
output_attribute: :hidden_state,
95+
embedding_processor: :l2_norm,
96+
defn_options: [compiler: EXLA, lazy_transfers: :never]
97+
)
98+
{:reply, {serving, index}, state}
99+
end
100+
101+
def handle_call(:get_count, _, {_, _, index} = state) do
102+
{:ok, count} = HNSWLib.Index.get_current_count(index)
103+
{:reply, count, state}
104+
end
105+
106+
def handle_call(:get_index, _, {_, _, index} = state) do
107+
{:reply, index, state}
108+
end
109+
end
110+
111+
{:ok, pid} = GenServer.start_link(Embedding, :l2)
112+
113+
{serving, index} = GenServer.call(pid, :serve)
114+
115+
%{embedding: data} = Nx.Serving.run(serving, "small") |>dbg()
116+
HNSWLib.Index.add_items(index, data)
117+
GenServer.call(pid, :get_count) |> dbg()
118+
119+
%{embedding: data} = Nx.Serving.run(serving, "tall") |> dbg()
120+
HNSWLib.Index.add_items(index, data)
121+
GenServer.call(pid, :get_count) |> dbg()
122+
123+
%{embedding: data3} = Nx.Serving.run(serving, "high")
124+
{:ok, labels, distances} = HNSWLib.Index.knn_query(index, data3, k: 1) |> dbg()
125+
idx = Nx.to_flat_list(labels[0])
126+
{:ok, dt} = HNSWLib.Index.get_items(index, idx)
127+
Nx.stack(Enum.map(dt, fn d -> Nx.from_binary(d, :f32) end))

lib/app/application.ex

+40-12
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,28 @@ defmodule App.Application do
55
require Logger
66
use Application
77

8-
@impl true
9-
def start(_type, _args) do
8+
@upload_dir Application.app_dir(:app, ["priv", "static", "uploads"])
109

10+
@saved_index if Application.compile_env(:app, :knnindex_indices_test, false),
11+
do: Path.join(@upload_dir, "indexes_test.bin"),
12+
else: Path.join(@upload_dir, "indexes.bin")
13+
14+
def check_models_on_startup do
1115
App.Models.verify_and_download_models()
16+
|> case do
17+
{:error, msg} ->
18+
Logger.error("⚠️ #{msg}")
19+
System.stop(0)
20+
21+
:ok ->
22+
Logger.info("ℹ️ Models: ✅")
23+
:ok
24+
end
25+
end
26+
27+
@impl true
28+
def start(_type, _args) do
29+
:ok = check_models_on_startup()
1230

1331
children = [
1432
# Start the Telemetry supervisor
@@ -18,17 +36,16 @@ defmodule App.Application do
1836
# Start the PubSub system
1937
{Phoenix.PubSub, name: App.PubSub},
2038
# Nx serving for the embedding
21-
# App.TextEmbedding,
22-
39+
{Nx.Serving, serving: App.Models.embedding(), name: Embedding, batch_size: 1},
2340
# Nx serving for Speech-to-Text
2441
{Nx.Serving,
25-
serving:
26-
if Application.get_env(:app, :use_test_models) == true do
27-
App.Models.audio_serving_test()
28-
else
29-
App.Models.audio_serving()
30-
end,
31-
name: Whisper},
42+
serving:
43+
if Application.get_env(:app, :use_test_models) == true do
44+
App.Models.audio_serving_test()
45+
else
46+
App.Models.audio_serving()
47+
end,
48+
name: Whisper},
3249
# Nx serving for image classifier
3350
{Nx.Serving,
3451
serving:
@@ -39,7 +56,7 @@ defmodule App.Application do
3956
end,
4057
name: ImageClassifier},
4158
{GenMagic.Server, name: :gen_magic},
42-
59+
4360
# Adding a supervisor
4461
{Task.Supervisor, name: App.TaskSupervisor},
4562
# Start the Endpoint (http/https)
@@ -48,6 +65,17 @@ defmodule App.Application do
4865
# {App.Worker, arg}
4966
]
5067

68+
# We are starting the HNSWLib Index GenServer only during testing.
69+
# Because this GenServer needs the database to be seeded first,
70+
# we only add it when we're not testing.
71+
# When testing, you need to spawn this process manually (it is done in the test_helper.exs file).
72+
children =
73+
if Application.get_env(:app, :start_genserver, true) == true do
74+
Enum.concat(children, [{App.KnnIndex, [space: :cosine, index: @saved_index]}])
75+
else
76+
children
77+
end
78+
5179
# See https://hexdocs.pm/elixir/Supervisor.html
5280
# for other strategies and supported options
5381
opts = [strategy: :one_for_one, name: App.Supervisor]

0 commit comments

Comments
 (0)