Skip to content

Commit 881d399

Browse files
Merge pull request #95 from tjmlabs/improved-max-sim
Improved max sim
2 parents c9d29da + f4b9e54 commit 881d399

File tree

4 files changed

+63
-5
lines changed

4 files changed

+63
-5
lines changed
web/api/migrations/0024_update_max_sim_function.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Generated by Django 5.1.3 on 2024-11-18 02:51
2+
3+
from django.db import migrations
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("api", "0023_remove_document_base64"),
9+
]
10+
11+
operations = [
12+
migrations.RunSQL(
13+
sql="""
14+
CREATE OR REPLACE FUNCTION max_sim(document halfvec[], query halfvec[]) RETURNS double precision AS $$
15+
WITH queries AS (
16+
SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query)
17+
),
18+
documents AS (
19+
SELECT unnest(document) AS document
20+
),
21+
similarities AS (
22+
SELECT query_number, (document <#> query) * -1 AS similarity
23+
FROM queries CROSS JOIN documents
24+
),
25+
max_similarities AS (
26+
SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
27+
)
28+
SELECT SUM(max_similarity) FROM max_similarities;
29+
$$ LANGUAGE SQL;
30+
""",
31+
# Rollback to the original function, which scores with cosine similarity (1 - cosine distance)
32+
reverse_sql="""
33+
CREATE OR REPLACE FUNCTION max_sim(document halfvec[], query halfvec[]) RETURNS double precision AS $$
34+
WITH queries AS (
35+
SELECT row_number() OVER () AS query_number, * FROM (SELECT unnest(query) AS query)
36+
),
37+
documents AS (
38+
SELECT unnest(document) AS document
39+
),
40+
similarities AS (
41+
SELECT query_number, 1 - (document <=> query) AS similarity FROM queries CROSS JOIN documents
42+
),
43+
max_similarities AS (
44+
SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
45+
)
46+
SELECT SUM(max_similarity) FROM max_similarities;
47+
$$ LANGUAGE SQL;
48+
""",
49+
)
50+
]
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Generated by Django 5.1.3 on 2024-11-19 13:15
2+
3+
from django.db import migrations
4+
5+
6+
class Migration(migrations.Migration):
7+
dependencies = [
8+
("api", "0024_update_max_sim_function"),
9+
]
10+
11+
operations = []

web/api/models.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
from django_stubs_ext.db.models import TypedModelMeta
2121
from pdf2image import convert_from_bytes
2222
from pgvector.django import HalfVectorField
23-
from tenacity import (retry, retry_if_exception_type, stop_after_attempt,
24-
wait_fixed)
23+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
2524

2625
logger = logging.getLogger(__name__)
2726

@@ -694,4 +693,3 @@ class PageEmbedding(models.Model):
694693
class MaxSim(Func):
695694
function = "max_sim"
696695
output_field = FloatField()
697-
output_field = FloatField()

web/api/views.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -934,8 +934,7 @@ async def search(
934934
"max_sim",
935935
)
936936
# Normalization
937-
extra_tokens = 12
938-
normalization_factor = query_length + extra_tokens
937+
normalization_factor = query_length
939938

940939
# Format the results
941940
formatted_results = [

0 commit comments

Comments
 (0)