-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhybrid_rrf.rb
More file actions
executable file
·146 lines (124 loc) · 4.73 KB
/
hybrid_rrf.rb
File metadata and controls
executable file
·146 lines (124 loc) · 4.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env ruby
# frozen_string_literal: true
require_relative "setup"
# SQL fragment builders for Reciprocal Rank Fusion (RRF) scoring.
module RrfHelpers
  # A bare SQL identifier. rank_column is interpolated into SQL unquoted, so
  # anything else is rejected.
  SQL_IDENTIFIER = /\A[A-Za-z_][A-Za-z0-9_]*\z/

  module_function

  # Builds the SQL expression for one ranker's weighted RRF term:
  #   weight / (rrf_k + <rank_column>.rank_position)
  # Both constants are emitted as float8 literals.
  #
  # @param weight [Numeric] multiplier for this ranker's contribution (converted with #to_f)
  # @param rrf_k [Numeric] RRF smoothing constant; must be finite and >= 0
  # @param rank_column [String, Symbol] alias of the relation exposing rank_position
  # @return [String] the SQL expression
  # @raise [ArgumentError] in any of these cases:
  #   - weight or rrf_k is not finite. Ruby renders such values as
  #     "Infinity"/"NaN", which are not numeric literals in SQL.
  #   - rrf_k is negative. RRF's k is a non-negative constant, and a negative
  #     value can drive the denominator to zero.
  #   - rank_column is not a plain identifier.
  def rrf_score(weight:, rrf_k:, rank_column:)
    weight_f = weight.to_f
    k_f = rrf_k.to_f

    unless weight_f.finite?
      raise ArgumentError, "weight must be finite, got #{weight.inspect}"
    end
    unless k_f.finite? && k_f >= 0
      raise ArgumentError, "rrf_k must be finite and non-negative, got #{rrf_k.inspect}"
    end
    unless SQL_IDENTIFIER.match?(rank_column.to_s)
      raise ArgumentError, "rank_column must be a plain SQL identifier, got #{rank_column.inspect}"
    end

    "#{weight_f}::float8 / (#{k_f}::float8 + #{rank_column}.rank_position)"
  end
end
# Builds the body of the "fulltext" CTE. It takes the top +top_k+ ParadeDB
# full-text matches of +query+ against +description+ and pairs each id with
# its 1-based rank by search score.
#
# @param query [String] search text, parsed leniently by the ParadeDB DSL
# @param top_k [Integer] maximum number of candidates to rank
# @return relation selecting (id, rank_position)
def fulltext_ranked_cte(query, top_k:)
  fulltext_source = MockItem.search(:description)
    .parse(query, lenient: true)
    .with_score
    .order(search_score: :desc)
    .limit(top_k)

  # ROW_NUMBER() numbers tied rows in an unspecified order. Breaking score ties
  # on id keeps rank_position, and therefore the RRF score, deterministic.
  MockItem.from(fulltext_source, :fulltext_source)
    .select("fulltext_source.id",
            "ROW_NUMBER() OVER (ORDER BY fulltext_source.search_score DESC, fulltext_source.id ASC) AS rank_position")
end
# Builds the body of the "semantic" CTE. It takes the +top_k+ nearest
# neighbours of +query_embedding+ by cosine distance on +embedding+ and pairs
# each id with its 1-based rank (closest first).
#
# @param query_embedding the query vector passed to the Neighbor DSL
# @param top_k [Integer] maximum number of candidates to rank
# @return relation selecting (id, rank_position)
def semantic_ranked_cte(query_embedding, top_k:)
  semantic_source = MockItem.nearest_neighbors(:embedding, query_embedding, distance: "cosine")
    .limit(top_k)

  # ROW_NUMBER() numbers tied rows in an unspecified order. Breaking distance
  # ties on id keeps rank_position, and therefore the RRF score, deterministic.
  MockItem.from(semantic_source, :semantic_source)
    .select("semantic_source.id",
            "ROW_NUMBER() OVER (ORDER BY semantic_source.neighbor_distance ASC, semantic_source.id ASC) AS rank_position")
end
# Builds one branch of the "contributions" UNION. Every row of the ranked CTE
# +source+ is tagged with its rank and its weighted RRF score.
#
# Both branches must emit the same six columns in the same order:
#   id, bm25_rank, semantic_rank, bm25_rrf, semantic_rrf, hybrid_rrf
# Generating them here, in one place, keeps the branches aligned. The ranker
# that did not produce the row contributes a NULL rank and a 0.0 score.
#
# @param source [String] alias of the ranked CTE ("fulltext" or "semantic")
# @param ranker [String] which ranker +source+ represents ("bm25" or "semantic")
# @param weight [Numeric] weight applied to this ranker's RRF term
# @param rrf_k [Numeric] RRF smoothing constant
# @return relation selecting the six contribution columns
def rrf_contribution_cte(source, ranker:, weight:, rrf_k:)
  contribution = RrfHelpers.rrf_score(weight: weight, rrf_k: rrf_k, rank_column: source)
  rankers = %w[bm25 semantic]

  rank_columns = rankers.map do |name|
    name == ranker ? "#{source}.rank_position AS #{name}_rank" : "NULL::integer AS #{name}_rank"
  end
  rrf_columns = rankers.map do |name|
    name == ranker ? "#{contribution} AS #{name}_rrf" : "0.0::float8 AS #{name}_rrf"
  end

  MockItem.from(source)
    .select("#{source}.id", *rank_columns, *rrf_columns, "#{contribution} AS hybrid_rrf")
end

# Builds the full-text branch of "contributions" from the "fulltext" CTE.
#
# @param weight [Numeric] weight of the full-text (BM25) ranker
# @param rrf_k [Numeric] RRF smoothing constant
# @return relation selecting the six contribution columns
def bm25_contribution_cte(weight:, rrf_k:)
  rrf_contribution_cte("fulltext", ranker: "bm25", weight: weight, rrf_k: rrf_k)
end

# Builds the semantic branch of "contributions" from the "semantic" CTE.
#
# @param weight [Numeric] weight of the semantic (nearest-neighbour) ranker
# @param rrf_k [Numeric] RRF smoothing constant
# @return relation selecting the six contribution columns
def semantic_contribution_cte(weight:, rrf_k:)
  rrf_contribution_cte("semantic", ranker: "semantic", weight: weight, rrf_k: rrf_k)
end
# Builds the body of the "hybrid_scores" CTE, which collapses the
# "contributions" rows into one row per item id.
#
# A branch that did not rank an id emits NULL for that rank and 0.0 for that
# score. MAX() therefore recovers the non-NULL rank, and SUM() adds the
# weighted RRF terms from both rankers.
#
# @return relation selecting (id, bm25_rank, semantic_rank, bm25_rrf,
#   semantic_rrf, hybrid_score), grouped by id
def combined_scores_cte
  src = "contributions"
  rank_columns = %w[bm25_rank semantic_rank].map { |col| "MAX(#{src}.#{col}) AS #{col}" }
  score_columns = %w[bm25_rrf semantic_rrf].map { |col| "SUM(#{src}.#{col}) AS #{col}" }

  MockItem.from(src)
    .select("#{src}.id", *rank_columns, *score_columns, "SUM(#{src}.hybrid_rrf) AS hybrid_score")
    .group("#{src}.id")
end
# Runs the hybrid search as a single CTE-based SQL query:
# 1. Rank the top +top_k+ full-text matches and the top +top_k+ nearest
#    neighbours separately.
# 2. Fuse the two rankings with weighted Reciprocal Rank Fusion.
# 3. Join the best +limit+ ids back to the items table.
#
# @param query [String] search text; also turned into an embedding for the semantic ranker
# @param top_k [Integer] candidates taken from each ranker before fusion
# @param limit [Integer] number of fused results returned
# @param rrf_k [Numeric] RRF smoothing constant
# @param bm25_weight [Numeric] weight of the full-text ranker
# @param semantic_weight [Numeric] weight of the semantic ranker
# @return relation of items with description, per-ranker ranks and RRF terms,
#   and hybrid_score, ordered by hybrid_score DESC then id ASC
def hybrid_search(query, top_k: 20, limit: 5, rrf_k: 60, bm25_weight: 1.0, semantic_weight: 1.0)
  embedding = HybridRrfSetup.query_embedding_for(query)

  # CTE order matters: each CTE reads the ones declared before it.
  ctes = {
    fulltext: fulltext_ranked_cte(query, top_k: top_k),
    semantic: semantic_ranked_cte(embedding, top_k: top_k),
    # NOTE(review): an array value is presumably combined with UNION ALL by
    # ActiveRecord's `with` (Rails 7.1+) — confirm against the Rails version in use.
    contributions: [
      bm25_contribution_cte(weight: bm25_weight, rrf_k: rrf_k),
      semantic_contribution_cte(weight: semantic_weight, rrf_k: rrf_k)
    ],
    hybrid_scores: combined_scores_cte
  }

  items = MockItem.table_name
  columns = ["#{items}.id", "#{items}.description"] +
            %w[bm25_rank semantic_rank bm25_rrf semantic_rrf hybrid_score].map { |col| "hybrid_scores.#{col}" }

  MockItem.with(**ctes)
    .from("hybrid_scores")
    .joins("JOIN #{items} ON #{items}.id = hybrid_scores.id")
    .select(*columns)
    .order("hybrid_scores.hybrid_score DESC, #{items}.id ASC")
    .limit(limit)
end
# Prints a banner for +query+, followed by one numbered line per result row.
# Prints " No results." instead when +rows+ is empty.
#
# @param query [String] query text shown in the banner
# @param rows [Array] result rows responding to description, hybrid_score,
#   bm25_rank and semantic_rank; a nil rank is shown as "--"
def display_results(query, rows)
  banner = "=" * 80
  puts "\n#{banner}", "Query: '#{query}'", banner
  return puts(" No results.") if rows.empty?

  template = " %<rank>d. %<desc>-60s hybrid=%<hybrid>.4f bm25_rank=%<bm25>s semantic_rank=%<semantic>s"
  rows.each.with_index(1) do |row, position|
    line = format(
      template,
      rank: position,
      desc: row.description.truncate(60),
      hybrid: row.hybrid_score,
      bm25: row.bm25_rank&.to_i || "--",
      semantic: row.semantic_rank&.to_i || "--"
    )
    puts line
  end
end
# Demo entry point. It runs only when this file is executed directly, not when
# it is required. It prepares the demo data, then prints the fused top results
# for a few sample queries.
if $PROGRAM_NAME == __FILE__
  rule = "=" * 80
  puts rule
  puts "Hybrid Search with Reciprocal Rank Fusion (single SQL query)"
  puts rule
  puts "\nCombining ParadeDB DSL + Neighbor DSL in one CTE-based query"

  HybridRrfSetup.setup!
  # Clear cached column info after setup. Presumably setup! (re)creates the
  # table or its columns; confirm in setup.rb.
  MockItem.reset_column_information

  sample_queries = ["running shoes", "footwear for exercise", "wireless earbuds"]
  sample_queries.each do |text|
    display_results(text, hybrid_search(text, top_k: 20, limit: 5).to_a)
  end

  puts "\n#{rule}"
  puts "Done!"
end