Skip to content

Commit 46996fa

Browse files
feat: packages-similarity-search (#210)
* chore: update types * chore: update types * feat: add initial semantic search * chore: refine search thresholds * chore: update type * chore: update types * chore: refactor components * wip: ui * feat: add new ui * chore: update types * wip: re-ranker * feat: combined results * chore: update wording * chore: update threshold * chore: bump version * chore: update label * chore: update label
1 parent bee3f32 commit 46996fa

15 files changed

+851
-437
lines changed

web/app/data/author.service.ts

Lines changed: 31 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -203,19 +203,41 @@ export class AuthorService {
203203
* @returns
204204
*/
205205
static async searchAuthors(query: string, options?: { limit?: number }) {
206-
const { limit = 20 } = options || {};
206+
const { limit = 8 } = options || {};
207+
208+
const [fts, exact] = await Promise.all([
209+
supabase.rpc("find_closest_authors", {
210+
search_term: query,
211+
result_limit: limit,
212+
max_levenshtein_distance: 8,
213+
}),
214+
supabase
215+
.from("authors")
216+
.select("id,name")
217+
.ilike("name", query)
218+
.maybeSingle(),
219+
]);
220+
221+
if (fts.error) {
222+
slog.error("Error in searchAuthors", fts.error);
223+
throw fts.error;
224+
}
225+
if (exact.error) {
226+
slog.error("Error in searchAuthors", exact.error);
227+
throw exact.error;
228+
}
207229

208-
const { data, error } = await supabase.rpc("find_closest_authors", {
209-
search_term: query.trim(),
210-
result_limit: limit,
211-
});
230+
const lexical = fts.data;
212231

213-
if (error) {
214-
slog.error("Error in searchAuthors", error);
215-
return [];
232+
if (exact.data) {
233+
lexical.unshift({
234+
id: exact.data.id,
235+
name: exact.data.name,
236+
levenshtein_distance: 0,
237+
});
216238
}
217239

218-
return uniqBy(data, (author) => author.id);
240+
return uniqBy(lexical, (author) => author.id);
219241
}
220242

221243
/**

web/app/data/package.service.ts

Lines changed: 119 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ import { Tables } from "./supabase.types.generated";
44
import { supabase } from "./supabase.server";
55
import { slog } from "../modules/observability.server";
66
import { authorIdSchema } from "./author.shape";
7-
import { uniqBy } from "es-toolkit";
7+
import { groupBy, uniqBy } from "es-toolkit";
88
import TTLCache from "@isaacs/ttlcache";
99
import { format, hoursToMilliseconds } from "date-fns";
10+
import { embed } from "ai";
11+
import { google } from "@ai-sdk/google";
1012

1113
type Package = Tables<"cran_packages">;
1214

@@ -164,37 +166,129 @@ export class PackageService {
164166
) {
165167
const { limit = 20 } = options || {};
166168

167-
const [fts, exact] = await Promise.all([
168-
supabase.rpc("find_closest_packages", {
169-
search_term: query,
170-
result_limit: limit,
171-
}),
172-
// ! ilike is expensive, but we want to make sure we get the exact match w/o case sensitivity.
173-
supabase
174-
.from("cran_packages")
175-
.select("id,name")
176-
.ilike("name", query)
177-
.maybeSingle(),
178-
]);
179-
180-
if (fts.error) {
181-
slog.error("Error in searchPackages", fts.error);
182-
return [];
169+
const isSimilaritySearchEnabled = query.length >= 3;
170+
171+
const [packageFTS, packageExact, embeddingSimilarity, embeddingFTS] =
172+
await Promise.all([
173+
supabase.rpc("find_closest_packages", {
174+
search_term: query,
175+
result_limit: limit,
176+
}),
177+
// ! ilike is expensive, but we want to make sure we get the exact match w/o case sensitivity.
178+
supabase
179+
.from("cran_packages")
180+
.select("id,name,synopsis")
181+
.ilike("name", query)
182+
.maybeSingle(),
183+
isSimilaritySearchEnabled
184+
? supabase.rpc("match_package_embeddings", {
185+
query_embedding: await embed({
186+
value: query,
187+
model: google.textEmbeddingModel("text-embedding-004"),
188+
}).then((res) => res.embedding as unknown as string),
189+
match_threshold: 0.4,
190+
match_count: limit,
191+
})
192+
: null,
193+
isSimilaritySearchEnabled
194+
? supabase.rpc("find_closest_package_embeddings", {
195+
search_term: query,
196+
result_limit: limit,
197+
})
198+
: null,
199+
]);
200+
201+
if (packageFTS.error) {
202+
slog.error("Error in searchPackages", packageFTS.error);
203+
throw packageFTS.error;
183204
}
184205

185-
if (exact.error) {
186-
slog.error("Error in searchPackages", exact.error);
187-
return [];
206+
if (packageExact.error) {
207+
slog.error("Error in searchPackages", packageExact.error);
208+
throw packageExact.error;
188209
}
189210

190-
if (exact.data) {
191-
fts.data.unshift({
192-
...exact.data,
193-
levenshtein_distance: 0,
211+
if (packageExact.data) {
212+
packageFTS.data.unshift({
213+
...packageExact.data,
214+
levenshtein_distance: 0.4,
194215
});
195216
}
196217

197-
return uniqBy(fts.data, (item) => item.id);
218+
if (embeddingSimilarity) {
219+
if (embeddingSimilarity.error) {
220+
slog.error("Error in searchPackages", embeddingSimilarity.error);
221+
}
222+
}
223+
224+
// Prefer the exact match over the similarity match.
225+
// Therefore we filter out the similarity match if it's the same as the exact match.
226+
const sources = [
227+
...(embeddingFTS?.data || []),
228+
...(embeddingSimilarity?.data || []),
229+
].filter((item) => {
230+
const hasExactMatch =
231+
packageExact.data && packageExact.data.id === item.cran_package_id;
232+
if (hasExactMatch) {
233+
return false;
234+
}
235+
return true;
236+
});
237+
238+
const lexical = uniqBy(packageFTS.data, (item) => item.id)
239+
.filter((item) => {
240+
return !sources.some((s) => s.cran_package_id === item.id);
241+
})
242+
.map((item) => ({
243+
name: item.name,
244+
synopsis: item.synopsis,
245+
}));
246+
247+
// Group sources by package id and source name, so that multiple hits per source & package
248+
// can be grouped together. `Object.values` is used to convert the object back to an array.
249+
const sourcesByPackage = groupBy(sources, (item) => item.cran_package_id);
250+
const groupedSourcesByPackageIds = Object.entries(sourcesByPackage).map(
251+
([packageId, sources]) => ({
252+
packageId,
253+
sources: groupBy(sources, (item) => item.source_name),
254+
}),
255+
);
256+
257+
// Fetch the package name for each package id.
258+
// This is not done inside the RPC call as we could
259+
// potentially have different package families (CRAN, Bioconductor, etc.).
260+
const groupedSourcesByPackage = await Promise.all(
261+
groupedSourcesByPackageIds.map(async (item) => {
262+
const { data, error } = await supabase
263+
.from("cran_packages")
264+
.select("name,synopsis")
265+
.eq("id", item.packageId)
266+
.maybeSingle();
267+
268+
if (error || !data) {
269+
slog.error("Error in searchPackages", error);
270+
return null;
271+
}
272+
273+
return {
274+
name: data.name,
275+
synopsis: data.synopsis,
276+
sources: Object.entries(item.sources),
277+
};
278+
}),
279+
);
280+
281+
const isSemanticPreferred =
282+
!packageExact.data && isSimilaritySearchEnabled && sources.length > 0;
283+
284+
return {
285+
lexical,
286+
semantic: groupedSourcesByPackage,
287+
combined: isSemanticPreferred
288+
? [...groupedSourcesByPackage, ...lexical]
289+
: [...lexical, ...groupedSourcesByPackage],
290+
isSemanticPreferred,
291+
};
198292
}
199293

200294
private static sanitizeSitemapName(name: string) {

0 commit comments

Comments
 (0)