Skip to content

Commit 3e9b0f7

Browse files
authored
Merge pull request #60 from traPtitech/staging
feat: 引用メッセージのコンテキストへの取り込み
2 parents 95da448 + 6c3f4fc commit 3e9b0f7

File tree

9 files changed

+227
-6
lines changed

9 files changed

+227
-6
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ require (
2828
github.com/kr/text v0.2.0 // indirect
2929
github.com/mfridman/interpolate v0.0.2 // indirect
3030
github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a // indirect
31+
github.com/motoki317/sc v1.8.2 // indirect
3132
github.com/pkg/errors v0.9.1 // indirect
3233
github.com/rogpeppe/go-internal v1.14.1 // indirect
3334
github.com/sethvargo/go-retry v0.3.0 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866
8484
github.com/milvus-io/milvus-proto/go-api/v2 v2.4.10-0.20240819025435-512e3b98866a/go.mod h1:1OIl0v5PQeNxIJhCvY+K55CBUOYDZevw9g9380u1Wek=
8585
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2 h1:Xqf+S7iicElwYoS2Zly8Nf/zKHuZsNy1xQajfdtygVY=
8686
github.com/milvus-io/milvus-sdk-go/v2 v2.4.2/go.mod h1:ulO1YUXKH0PGg50q27grw048GDY9ayB4FPmh7D+FFTA=
87+
github.com/motoki317/sc v1.8.2 h1:JzhmFKl4ZS0VxuRYRBQ07o3DcGdvhn2NwqnUcPJmCjY=
88+
github.com/motoki317/sc v1.8.2/go.mod h1:IwywgSXTlBxHV8a6lHNiQYmTBh7Dc4f9KjzXVdl8/Bk=
8789
github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4=
8890
github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
8991
github.com/openai/openai-go/v2 v2.7.1 h1:/tfvTJhfv7hTSL8mWwc5VL4WLLSDL5yn9VqVykdu9r8=

internal/bot/channel.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package bot
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/motoki317/sc"
8+
)
9+
10+
func getChannelPathInternal(ctx context.Context, channelID string) (string, error) {
11+
bot := GetBot()
12+
13+
path, _, err := bot.API().ChannelAPI.GetChannelPath(ctx, channelID).Execute()
14+
if err != nil {
15+
return "", err
16+
}
17+
18+
return path.Path, nil
19+
}
20+
21+
var channelPathCache = sc.NewMust(getChannelPathInternal, time.Hour, time.Hour)
22+
23+
func GetChannelPath(channelID string) (string, error) {
24+
return channelPathCache.Get(context.Background(), channelID)
25+
}

internal/bot/user.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package bot
2+
3+
import (
4+
"context"
5+
"time"
6+
7+
"github.com/motoki317/sc"
8+
"github.com/traPtitech/go-traq"
9+
)
10+
11+
func getUserInternal(ctx context.Context, userID string) (*traq.UserDetail, error) {
12+
bot := GetBot()
13+
14+
user, _, err := bot.API().UserAPI.GetUser(ctx, userID).Execute()
15+
if err != nil {
16+
return nil, err
17+
}
18+
19+
return user, nil
20+
}
21+
22+
var userCache = sc.NewMust(getUserInternal, time.Hour, time.Hour)
23+
24+
func GetUser(userID string) (*traq.UserDetail, error) {
25+
return userCache.Get(context.Background(), userID)
26+
}

internal/handler/OnDirectMessageCreated.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package handler
22

33
import (
4+
"log"
5+
46
"github.com/traPtitech/BOT_GPT/internal/bot"
7+
"github.com/traPtitech/BOT_GPT/internal/pkg/formatter"
58
"github.com/traPtitech/traq-ws-bot/payload"
6-
"log"
79
)
810

911
func (h *Handler) DirectMessageReceived() func(p *payload.DirectMessageCreated) {
@@ -16,12 +18,18 @@ func (h *Handler) DirectMessageReceived() func(p *payload.DirectMessageCreated)
1618
return
1719
}
1820

19-
plainTextWithoutMention := bot.RemoveFirstBotID(p.Message.PlainText)
21+
textWithEmbed := formatter.FormatEmbeds(p.Message.Text)
22+
textWithoutMention := bot.RemoveFirstBotID(textWithEmbed)
23+
formattedMessage, err := formatter.FormatQuotedMessage(p.Message.User.ID, textWithoutMention)
24+
if err != nil {
25+
log.Printf("Error formatting quoted message: %v\n", err)
26+
formattedMessage = textWithoutMention
27+
}
2028

2129
if p.Message.User.Name != "pikachu" {
2230
_ = bot.PostMessage(p.Message.ChannelID, "DMではあんまり沢山使わないでね。定期的な`/reset`を忘れない事。")
2331
}
2432

25-
messageReceived(p.Message.Text, plainTextWithoutMention, p.Message.ChannelID)
33+
messageReceived(p.Message.Text, formattedMessage, p.Message.ChannelID)
2634
}
2735
}

internal/handler/OnMessageCreated.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package handler
22

33
import (
4+
"log"
5+
46
"github.com/traPtitech/BOT_GPT/internal/bot"
7+
"github.com/traPtitech/BOT_GPT/internal/pkg/formatter"
58
"github.com/traPtitech/traq-ws-bot/payload"
6-
"log"
79
)
810

911
func (h *Handler) MessageReceived() func(p *payload.MessageCreated) {
@@ -16,8 +18,14 @@ func (h *Handler) MessageReceived() func(p *payload.MessageCreated) {
1618
return
1719
}
1820

19-
plainTextWithoutMention := bot.RemoveFirstBotID(p.Message.PlainText)
21+
textEmbedFormatted := formatter.FormatEmbeds(p.Message.Text)
22+
textWithoutMention := bot.RemoveFirstBotID(textEmbedFormatted)
23+
formattedMessage, err := formatter.FormatQuotedMessage(p.Message.User.ID, textWithoutMention)
24+
if err != nil {
25+
log.Printf("Error formatting quoted message: %v\n", err)
26+
formattedMessage = textWithoutMention
27+
}
2028

21-
messageReceived(p.Message.Text, plainTextWithoutMention, p.Message.ChannelID)
29+
messageReceived(p.Message.Text, formattedMessage, p.Message.ChannelID)
2230
}
2331
}

internal/pkg/formatter/embed.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
package formatter
2+
3+
import "regexp"
4+
5+
var embedRegex = regexp.MustCompile(`!\{"type":"(\w+?)","raw":"(.+?)","id":"[a-z0-9-]+?"\}`)
6+
7+
func FormatEmbeds(content string) string {
8+
return embedRegex.ReplaceAllString(content, "$2")
9+
}
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package formatter
2+
3+
import "testing"
4+
5+
func TestFormatEmbeds(t *testing.T) {
6+
tests := []struct {
7+
name string
8+
input string
9+
expected string
10+
}{
11+
{
12+
name: "No embeds",
13+
input: "Hello, world!",
14+
expected: "Hello, world!",
15+
},
16+
{
17+
name: "Single embed",
18+
input: `Hello, !{"type":"channel","raw":"#general","id":"04ad2c18-fdcb-4c43-beef-82e8ba26ac98"}, world`,
19+
expected: "Hello, #general, world",
20+
},
21+
{
22+
name: "Multiple embeds",
23+
input: `!{"type":"user","raw":"@cp20","id":"be77174f-13c5-4464-8b15-7f45b96d5b18"}!{"type":"channel","raw":"#general","id":"04ad2c18-fdcb-4c43-beef-82e8ba26ac98"}`,
24+
expected: "@cp20#general",
25+
},
26+
}
27+
28+
for _, tt := range tests {
29+
t.Run(tt.name, func(t *testing.T) {
30+
result := FormatEmbeds(tt.input)
31+
if result != tt.expected {
32+
t.Errorf("FormatEmbeds(%q) = %q; expected %q", tt.input, result, tt.expected)
33+
}
34+
})
35+
}
36+
}

internal/pkg/formatter/quotes.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
package formatter
2+
3+
import (
4+
"regexp"
5+
"strings"
6+
"unicode/utf8"
7+
8+
"github.com/traPtitech/BOT_GPT/internal/bot"
9+
"github.com/traPtitech/go-traq"
10+
)
11+
12+
const quoteRegexStr = `\bhttps://q\.trap\.jp/messages/([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})\b`
13+
14+
var quoteRegex = regexp.MustCompile(quoteRegexStr)
15+
16+
var allowingPrefixes = []string{"event", "general", "random", "services", "team/SysAd"}
17+
18+
func isChannelAllowingQuotes(channelID string) (bool, error) {
19+
channelPath, err := bot.GetChannelPath(channelID)
20+
if err != nil {
21+
return false, err
22+
}
23+
24+
for _, prefix := range allowingPrefixes {
25+
if strings.HasPrefix(channelPath, prefix) {
26+
return true, nil
27+
}
28+
}
29+
30+
return false, nil
31+
}
32+
33+
func isUserAllowingQuotes(userID string, messageUserID string) (bool, error) {
34+
if userID == messageUserID {
35+
return true, nil
36+
}
37+
38+
messageUser, err := bot.GetUser(messageUserID)
39+
if err != nil {
40+
return false, err
41+
}
42+
43+
if messageUser.Bot {
44+
return true, nil
45+
}
46+
47+
return false, nil
48+
}
49+
50+
func getQuoteMarkdown(message *traq.Message) (string, error) {
51+
user, err := bot.GetUser(message.UserId)
52+
if err != nil {
53+
return "", err
54+
}
55+
56+
return "> " + user.Name + ":\n> " + strings.ReplaceAll(message.Content, "\n", "\n> "), nil
57+
}
58+
59+
const maxQuoteLength = 10000
60+
61+
func FormatQuotedMessage(userID string, content string) (string, error) {
62+
matches := quoteRegex.FindAllSubmatch([]byte(content), len(content))
63+
messageIDs := make([]string, 0, len(matches))
64+
for _, match := range matches {
65+
if len(match) < 2 {
66+
continue
67+
}
68+
messageID := string(match[1])
69+
messageIDs = append(messageIDs, messageID)
70+
}
71+
72+
var formattedContent strings.Builder
73+
formattedContent.WriteString(quoteRegex.ReplaceAllString(content, ""))
74+
75+
for _, messageID := range messageIDs {
76+
message := bot.GetMessage(messageID)
77+
if message == nil {
78+
continue
79+
}
80+
81+
if utf8.RuneCountInString(message.Content) > maxQuoteLength {
82+
runes := []rune(message.Content)
83+
message.Content = string(runes[:maxQuoteLength]) + "(以下略)"
84+
}
85+
86+
channelAllowed, err := isChannelAllowingQuotes(message.ChannelId)
87+
if err != nil {
88+
return "", err
89+
}
90+
userAllowed, err := isUserAllowingQuotes(userID, message.UserId)
91+
if err != nil {
92+
return "", err
93+
}
94+
if !channelAllowed && !userAllowed {
95+
continue
96+
}
97+
98+
quote, err := getQuoteMarkdown(message)
99+
if err != nil {
100+
return "", err
101+
}
102+
formattedContent.WriteString("\n\n" + quote)
103+
}
104+
105+
return formattedContent.String(), nil
106+
}

0 commit comments

Comments
 (0)