Skip to content

Commit 46dba9c

Browse files
committed
rag
1 parent e186a6d commit 46dba9c

File tree

12 files changed

+531
-5
lines changed

12 files changed

+531
-5
lines changed

pom.xml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
<commons-codec.version>1.18.0</commons-codec.version>
2727
<openai-client-jvm.version>4.0.1</openai-client-jvm.version>
2828
<ktor-client-okhttp-jvm.version>3.1.3</ktor-client-okhttp-jvm.version>
29+
<pgvector.version>0.1.6</pgvector.version>
2930
</properties>
3031

3132
<dependencies>
@@ -132,6 +133,12 @@
132133
<version>${commons-codec.version}</version>
133134
</dependency>
134135

136+
<dependency>
137+
<groupId>com.pgvector</groupId>
138+
<artifactId>pgvector</artifactId>
139+
<version>${pgvector.version}</version>
140+
</dependency>
141+
135142
<!--Test-->
136143
<dependency>
137144
<groupId>org.springframework.boot</groupId>

scripts/db.sql

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,3 +327,22 @@ VALUES (1, '/stats_month'),
327327
(37, '/memory')
328328

329329
;
330+
CREATE EXTENSION IF NOT EXISTS vector;
331+
CREATE EXTENSION IF NOT EXISTS pg_trgm;
332+
CREATE TABLE IF NOT EXISTS rag_msg_index (
333+
rag_id BIGSERIAL PRIMARY KEY,
334+
chat_id BIGINT NOT NULL,
335+
msg_id BIGINT NOT NULL,
336+
ts TIMESTAMPTZ NOT NULL,
337+
user_id BIGINT NOT NULL,
338+
text TEXT NOT NULL,
339+
embedding vector(1536),
340+
tsv tsvector
341+
GENERATED ALWAYS AS (to_tsvector('russian', coalesce(text,''))) STORED
342+
);
343+
344+
CREATE INDEX IF NOT EXISTS rag_idx_chat_ts ON rag_msg_index (chat_id, ts DESC);
345+
CREATE INDEX IF NOT EXISTS rag_idx_chat_msg ON rag_msg_index (chat_id, msg_id);
346+
CREATE INDEX IF NOT EXISTS rag_idx_ivf ON rag_msg_index USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);
347+
CREATE INDEX IF NOT EXISTS rag_idx_tsv ON rag_msg_index USING GIN (tsv);
348+
CREATE INDEX IF NOT EXISTS rag_idx_trgm ON rag_msg_index USING GIN (text gin_trgm_ops);

src/main/kotlin/dev/storozhenko/familybot/core/repos/UserRepository.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,14 @@ class UserRepository(private val template: JdbcTemplate) {
3232
return template.query(select) { rs, _ -> rs.toUser() }
3333
}
3434

35+
fun getUserNamesById(ids: List<Long>): Map<Long, String> {
36+
return template.query("select id, name, username from users where id in (${ids.joinToString(",")})", { rs, _ ->
37+
rs.getLong("id") to (rs.getString("name") + "|" + (rs.getString("username") ?: "N/A"))
38+
})
39+
.toMap()
40+
41+
}
42+
3543
fun addChat(chat: Chat) {
3644
template.update(
3745
"INSERT INTO chats (id, name) VALUES (?, ?) ON CONFLICT(id) DO UPDATE SET name = excluded.name, active = TRUE",

src/main/kotlin/dev/storozhenko/familybot/core/routers/Router.kt

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,17 @@ import dev.storozhenko.familybot.feature.settings.models.CommandLimit
2525
import dev.storozhenko.familybot.feature.settings.models.FirstBotInteraction
2626
import dev.storozhenko.familybot.feature.settings.models.FirstTimeInChat
2727
import dev.storozhenko.familybot.feature.settings.models.MessageCounter
28+
import dev.storozhenko.familybot.feature.settings.models.RagContext
2829
import dev.storozhenko.familybot.feature.settings.repos.FunctionsConfigureRepository
2930
import dev.storozhenko.familybot.feature.talking.services.AdManager
3031
import dev.storozhenko.familybot.feature.talking.services.Dictionary
32+
import dev.storozhenko.familybot.feature.talking.services.rag.RagService
3133
import io.github.oshai.kotlinlogging.KotlinLogging
3234
import kotlinx.coroutines.CoroutineExceptionHandler
3335
import kotlinx.coroutines.CoroutineScope
3436
import kotlinx.coroutines.Dispatchers
3537
import kotlinx.coroutines.SupervisorJob
38+
import kotlinx.coroutines.coroutineScope
3639
import kotlinx.coroutines.launch
3740
import org.springframework.stereotype.Component
3841
import org.telegram.telegrambots.meta.api.objects.Update
@@ -55,6 +58,7 @@ class Router(
5558
private val dictionary: Dictionary,
5659
private val easyKeyValueService: EasyKeyValueService,
5760
private val adManager: AdManager,
61+
private val ragService: RagService,
5862
) {
5963

6064
private val logger = KotlinLogging.logger { }
@@ -78,7 +82,18 @@ class Router(
7882
}
7983
}
8084
val context = update.context(botConfig, dictionary, client)
81-
85+
coroutineScope {
86+
launch {
87+
if (easyKeyValueService.get(
88+
RagContext,
89+
update.toChat().key(),
90+
false
91+
) && botConfig.openAiToken != null
92+
) {
93+
ragService.add(context)
94+
}
95+
}
96+
}
8297
val executor = if (isGroup) {
8398
selectExecutor(context) ?: selectRandom(context)
8499
} else {

src/main/kotlin/dev/storozhenko/familybot/feature/settings/models/SettingTypes.kt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ object IGCookie : StringKeyType<PlainKey>
5454
object PaymentKey : StringKeyType<PlainKey>
5555
object RefundNeedsToPressTime : LongKeyType<PlainKey>
5656
object AdCooldown : InstantKeyType<ChatEasyKey>
57+
object RagContext: BooleanKeyType<ChatEasyKey>
5758

5859
object StoryGameActive : BooleanKeyType<ChatEasyKey>
5960
object StoryPollBlocked : InstantKeyType<ChatEasyKey>
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package dev.storozhenko.familybot.feature.settings.processors
2+
3+
import dev.storozhenko.familybot.common.extensions.getMessageTokens
4+
import dev.storozhenko.familybot.core.keyvalue.EasyKeyValueService
5+
import dev.storozhenko.familybot.core.routers.models.ExecutorContext
6+
import dev.storozhenko.familybot.feature.settings.models.RagContext
7+
import org.springframework.stereotype.Component
8+
9+
@Component
10+
class RagSettingProcessor(
11+
private val easyKeyValueService: EasyKeyValueService,
12+
) : SettingProcessor {
13+
override fun canProcess(context: ExecutorContext): Boolean {
14+
return context.isFromDeveloper && context.update.getMessageTokens()[1] == "rag"
15+
}
16+
17+
override suspend fun process(context: ExecutorContext) {
18+
if (context.update.getMessageTokens().size < 3) {
19+
context.send("Not ok")
20+
return
21+
}
22+
val state = context.update.getMessageTokens()[2]
23+
if (state == "вкл") {
24+
easyKeyValueService.put(RagContext, context.chatKey, true)
25+
context.send("Ok")
26+
return
27+
}
28+
29+
if (state == "выкл") {
30+
easyKeyValueService.put(RagContext, context.chatKey, false)
31+
context.send("Ok")
32+
return
33+
}
34+
35+
context.send("Not ok")
36+
37+
38+
}
39+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package dev.storozhenko.familybot.feature.talking.models
2+
3+
import java.time.Instant
4+
5+
enum class Kind { SEMANTIC, KEYWORD_RU, KEYWORD_SIMPLE, FUZZY, RECENT }
6+
7+
data class RagHit(
8+
val ragId: Long,
9+
val msgId: Long,
10+
val userId: Long,
11+
val ts: Instant,
12+
val text: String,
13+
val score: Double,
14+
val kind: Kind,
15+
)
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
package dev.storozhenko.familybot.feature.talking.repos
2+
3+
import com.aallam.openai.api.embedding.Embedding
4+
import com.pgvector.PGvector
5+
import dev.storozhenko.familybot.core.routers.models.ExecutorContext
6+
import dev.storozhenko.familybot.feature.talking.models.Kind
7+
import dev.storozhenko.familybot.feature.talking.models.RagHit
8+
import org.springframework.jdbc.core.JdbcTemplate
9+
import org.springframework.stereotype.Component
10+
import java.sql.Timestamp
11+
import java.time.Instant
12+
import kotlin.math.exp
13+
14+
@Component
15+
class RagRepository(private val template: JdbcTemplate) {
16+
17+
18+
fun add(executorContext: ExecutorContext, embedding: Embedding) {
19+
20+
template.update(
21+
"""
22+
INSERT INTO rag_msg_index (chat_id, msg_id, ts, user_id, text, embedding)
23+
VALUES (?, ?, ?, ?, ?, ?)
24+
""".trimIndent(),
25+
executorContext.chat.id,
26+
executorContext.message.messageId,
27+
Timestamp.from(Instant.now()),
28+
executorContext.user.id,
29+
executorContext.message.text,
30+
PGvector(embedding.embedding)
31+
)
32+
33+
}
34+
35+
36+
fun searchSemantic(
37+
executorContext: ExecutorContext,
38+
queryEmbedding: Embedding,
39+
limit: Int = 20,
40+
): List<RagHit> {
41+
val vec = PGvector(queryEmbedding.embedding)
42+
val sql = """
43+
SELECT rag_id, msg_id, user_id, ts, text, (embedding <=> ?) AS dist
44+
FROM rag_msg_index
45+
WHERE chat_id = ? AND embedding IS NOT NULL
46+
ORDER BY embedding <=> ?
47+
LIMIT ?
48+
""".trimIndent()
49+
return template.query(
50+
sql,
51+
{ rs, _ ->
52+
val dist = rs.getDouble("dist")
53+
RagHit(
54+
ragId = rs.getLong("rag_id"),
55+
msgId = rs.getLong("msg_id"),
56+
userId = rs.getLong("user_id"),
57+
ts = rs.getTimestamp("ts").toInstant(),
58+
text = rs.getString("text"),
59+
score = 1.0 - dist,
60+
kind = Kind.SEMANTIC
61+
)
62+
},
63+
vec,
64+
executorContext.chat.id,
65+
vec,
66+
limit
67+
)
68+
}
69+
70+
fun searchKeywordRu(
71+
executorContext: ExecutorContext,
72+
q: String,
73+
limit: Int = 20,
74+
): List<RagHit> {
75+
val sql = """
76+
SELECT rag_id, msg_id, user_id, ts, text,
77+
ts_rank(tsv, plainto_tsquery('russian', ?)) AS kw_score
78+
FROM rag_msg_index
79+
WHERE chat_id = ? AND tsv @@ plainto_tsquery('russian', ?)
80+
ORDER BY kw_score DESC
81+
LIMIT ?
82+
""".trimIndent()
83+
return template.query(
84+
sql,
85+
{ rs, _ ->
86+
RagHit(
87+
ragId = rs.getLong("rag_id"),
88+
msgId = rs.getLong("msg_id"),
89+
userId = rs.getLong("user_id"),
90+
ts = rs.getTimestamp("ts").toInstant(),
91+
text = rs.getString("text"),
92+
score = rs.getDouble("kw_score"),
93+
kind = Kind.KEYWORD_RU
94+
)
95+
},
96+
q, executorContext.chat.id, q, limit
97+
)
98+
}
99+
100+
fun searchKeywordSimple(
101+
executorContext: ExecutorContext,
102+
q: String,
103+
limit: Int = 20,
104+
): List<RagHit> {
105+
val sql = """
106+
SELECT rag_id, msg_id, user_id, ts, text,
107+
ts_rank(to_tsvector('simple', coalesce(text,'')),
108+
plainto_tsquery('simple', ?)) AS kw_score
109+
FROM rag_msg_index
110+
WHERE chat_id = ?
111+
AND to_tsvector('simple', coalesce(text,'')) @@ plainto_tsquery('simple', ?)
112+
ORDER BY kw_score DESC
113+
LIMIT ?
114+
""".trimIndent()
115+
return template.query(
116+
sql,
117+
{ rs, _ ->
118+
RagHit(
119+
ragId = rs.getLong("rag_id"),
120+
msgId = rs.getLong("msg_id"),
121+
userId = rs.getLong("user_id"),
122+
ts = rs.getTimestamp("ts").toInstant(),
123+
text = rs.getString("text"),
124+
score = rs.getDouble("kw_score"),
125+
kind = Kind.KEYWORD_SIMPLE
126+
)
127+
},
128+
q, executorContext.chat.id, q, limit
129+
)
130+
}
131+
132+
fun recentWindow(
133+
executorContext: ExecutorContext,
134+
minutes: Long = 30,
135+
limit: Int = 40,
136+
): List<RagHit> {
137+
val sql = """
138+
SELECT rag_id, msg_id, user_id, ts, text
139+
FROM rag_msg_index
140+
WHERE chat_id = ?
141+
AND ts > now() - (? || ' minutes')::interval
142+
ORDER BY ts DESC
143+
LIMIT ?
144+
""".trimIndent()
145+
val now = Instant.now()
146+
return template.query(
147+
sql,
148+
{ rs, _ ->
149+
val ts = rs.getTimestamp("ts").toInstant()
150+
val minutesSince = (now.epochSecond - ts.epochSecond) / 60.0
151+
val rec = exp(-minutesSince / 720.0)
152+
RagHit(
153+
ragId = rs.getLong("rag_id"),
154+
msgId = rs.getLong("msg_id"),
155+
userId = rs.getLong("user_id"),
156+
ts = ts,
157+
text = rs.getString("text"),
158+
score = rec,
159+
kind = Kind.RECENT
160+
)
161+
},
162+
executorContext.chat.id, minutes.toString(), limit
163+
)
164+
}
165+
166+
fun searchFuzzy(
167+
executorContext: ExecutorContext,
168+
q: String,
169+
limit: Int = 10,
170+
): List<RagHit> {
171+
val sql = """
172+
SELECT rag_id, msg_id, user_id, ts, text, similarity(text, ?) AS sim
173+
FROM rag_msg_index
174+
WHERE chat_id = ? AND text % ?
175+
ORDER BY sim DESC
176+
LIMIT ?
177+
""".trimIndent()
178+
return template.query(
179+
sql,
180+
{ rs, _ ->
181+
RagHit(
182+
ragId = rs.getLong("rag_id"),
183+
msgId = rs.getLong("msg_id"),
184+
userId = rs.getLong("user_id"),
185+
ts = rs.getTimestamp("ts").toInstant(),
186+
text = rs.getString("text"),
187+
score = rs.getDouble("sim"),
188+
kind = Kind.FUZZY
189+
)
190+
},
191+
q, executorContext.chat.id, q, limit
192+
)
193+
}
194+
195+
196+
}

0 commit comments

Comments
 (0)