Skip to content

Commit

Permalink
Merge pull request #95 from tjmlabs/improved-max-sim
Browse files Browse the repository at this point in the history
Improved max sim
  • Loading branch information
Jonathan-Adly authored Nov 19, 2024
2 parents c9d29da + f4b9e54 commit 881d399
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 5 deletions.
50 changes: 50 additions & 0 deletions web/api/migrations/0024_update_max_sim_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Generated by Django 5.1.3 on 2024-11-18 02:51

from django.db import migrations


class Migration(migrations.Migration):
    """Redefine the ``max_sim`` SQL function to score with the inner product.

    Forward: ``max_sim(document, query)`` pairs every query vector with every
    document vector, keeps the best similarity per query vector and returns
    the sum of those maxima, where similarity is the inner product.

    Reverse: restores the same function with cosine similarity
    (``1 - cosine distance``) as the per-pair score.

    Both definitions use ``CREATE OR REPLACE``, so applying or reverting the
    migration swaps the function body in place.
    """

    dependencies = [
        ("api", "0023_remove_document_base64"),
    ]

    operations = [
        migrations.RunSQL(
            # NOTE: the derived table in ``queries`` carries an explicit alias
            # because PostgreSQL releases before 16 reject an unaliased
            # subquery in FROM.
            # pgvector's ``<#>`` operator yields the *negative* inner product,
            # hence the multiplication by -1 to get a similarity.
            # SUM over zero rows is NULL, so empty input arrays yield NULL.
            sql="""
            CREATE OR REPLACE FUNCTION max_sim(document halfvec[], query halfvec[]) RETURNS double precision AS $$
                WITH queries AS (
                    SELECT row_number() OVER () AS query_number, *
                    FROM (SELECT unnest(query) AS query) AS query_rows
                ),
                documents AS (
                    SELECT unnest(document) AS document
                ),
                similarities AS (
                    SELECT query_number, (document <#> query) * -1 AS similarity
                    FROM queries CROSS JOIN documents
                ),
                max_similarities AS (
                    SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
                )
                SELECT SUM(max_similarity) FROM max_similarities;
            $$ LANGUAGE SQL;
            """,
            # Rollback to original function using cosine distance
            # (``<=>`` is pgvector's cosine distance operator).
            reverse_sql="""
            CREATE OR REPLACE FUNCTION max_sim(document halfvec[], query halfvec[]) RETURNS double precision AS $$
                WITH queries AS (
                    SELECT row_number() OVER () AS query_number, *
                    FROM (SELECT unnest(query) AS query) AS query_rows
                ),
                documents AS (
                    SELECT unnest(document) AS document
                ),
                similarities AS (
                    SELECT query_number, 1 - (document <=> query) AS similarity
                    FROM queries CROSS JOIN documents
                ),
                max_similarities AS (
                    SELECT MAX(similarity) AS max_similarity FROM similarities GROUP BY query_number
                )
                SELECT SUM(max_similarity) FROM max_similarities;
            $$ LANGUAGE SQL;
            """,
        )
    ]
11 changes: 11 additions & 0 deletions web/api/migrations/0025_auto_20241119_1315.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Generated by Django 5.1.3 on 2024-11-19 13:15

from django.db import migrations


class Migration(migrations.Migration):
    """Empty migration: performs no schema or data operations.

    It only extends the migration chain after ``0024_update_max_sim_function``.
    """

    dependencies = [("api", "0024_update_max_sim_function")]

    # Intentionally empty: there is nothing to apply or revert.
    operations = []
4 changes: 1 addition & 3 deletions web/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@
from django_stubs_ext.db.models import TypedModelMeta
from pdf2image import convert_from_bytes
from pgvector.django import HalfVectorField
from tenacity import (retry, retry_if_exception_type, stop_after_attempt,
wait_fixed)
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -694,4 +693,3 @@ class PageEmbedding(models.Model):
class MaxSim(Func):
    """Query expression that renders a call to the SQL function ``max_sim``.

    The result is exposed to the ORM as a float.
    NOTE(review): ``Func`` and ``FloatField`` are imported outside this view,
    presumably from ``django.db.models`` - confirm.
    """

    # Name of the database function emitted in the generated SQL.
    function = "max_sim"
    # Declared once; a second identical assignment was redundant.
    output_field = FloatField()
3 changes: 1 addition & 2 deletions web/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,8 +934,7 @@ async def search(
"max_sim",
)
# Normalization
extra_tokens = 12
normalization_factor = query_length + extra_tokens
normalization_factor = query_length

# Format the results
formatted_results = [
Expand Down

0 comments on commit 881d399

Please sign in to comment.