-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmessages_by_room.go
More file actions
206 lines (187 loc) · 7.27 KB
/
Copy pathmessages_by_room.go
File metadata and controls
206 lines (187 loc) · 7.27 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
package cassrepo
import (
"context"
"fmt"
"time"
"github.com/gocql/gocql"
"github.com/hmchangw/chat/history-service/internal/models"
)
// Shared CQL fragments for reads against the messages_by_room table.
const (
	// baseColumns is the comma-separated column list selected for every
	// message row read from messages_by_room.
	baseColumns = "room_id, created_at, message_id, thread_room_id, sender, " +
		"msg, mentions, attachments, card, card_action, tshow, tcount, thread_last_msg_at, " +
		"thread_parent_id, thread_parent_created_at, quoted_parent_message, " +
		"visible_to, reactions, deleted, " +
		"type, sys_msg_data, site_id, edited_at, updated_at, pinned_at, " +
		"enc_payload, enc_meta"

	// messageByRoomQuery is the SELECT prefix over baseColumns; each query in
	// this file appends its own WHERE and ORDER BY clauses to it.
	messageByRoomQuery = "SELECT " + baseColumns + " FROM messages_by_room"
)
// startBucketFromCursor picks the bucket a walk starts from and the in-bucket
// page state to resume with.
//
// With no cursor, or a cursor that encodes to the empty string, the walk starts
// at defaultBucket with no page state. Otherwise the encoded cursor is decoded
// via decodeBucketCursor; a decode failure is returned wrapped. A decoded bucket
// outside the walk's range is not an error: the cursor is ignored and the walk
// restarts at defaultBucket with nil page state, so a tampered cursor cannot
// point the walk at an arbitrary bucket.
//
// The accepted range depends on direction:
//   - walkDesc: floorBucket <= bucket <= defaultBucket
//   - walkAsc:  defaultBucket <= bucket <= floorBucket (for ASC the "floor"
//     argument is really the ceiling)
//
// Any other direction value performs no range check.
func startBucketFromCursor(pageReq PageRequest, direction walkDirection, defaultBucket, floorBucket int64) (int64, []byte, error) {
	if pageReq.Cursor == nil {
		return defaultBucket, nil, nil
	}
	token := pageReq.Cursor.Encode()
	if token == "" {
		return defaultBucket, nil, nil
	}

	cursorBucket, resumeState, err := decodeBucketCursor(token)
	if err != nil {
		return 0, nil, fmt.Errorf("start bucket from cursor: %w", err)
	}

	outOfRange := false
	switch direction {
	case walkDesc:
		outOfRange = cursorBucket < floorBucket || cursorBucket > defaultBucket
	case walkAsc:
		outOfRange = cursorBucket > floorBucket || cursorBucket < defaultBucket
	}
	if outOfRange {
		// Fall back to a fresh walk rather than trusting the cursor.
		return defaultBucket, nil, nil
	}
	return cursorBucket, resumeState, nil
}
// scanMessagesUpTo builds the scan callback handed to fillPage by the message
// queries in this file.
//
// The returned function reads at most remaining rows from iter using
// structScan and passes each scanned message through r.decryptIfNeeded with
// ctx before collecting it. Scanning stops early once structScan reports that
// no further row is available. If either structScan or r.decryptIfNeeded
// fails, the callback returns a nil slice together with that error, so rows
// collected so far in this call are not handed back.
func (r *Repository) scanMessagesUpTo(ctx context.Context) func(iter *gocql.Iter, remaining int) ([]models.Message, error) {
	return func(iter *gocql.Iter, remaining int) ([]models.Message, error) {
		rows := make([]models.Message, 0, remaining)
		// Each pass either appends exactly one row or returns, so counting
		// passes is the same as checking len(rows) against remaining.
		for n := 0; n < remaining; n++ {
			var msg models.Message
			more, scanErr := structScan(iter, &msg)
			if scanErr != nil {
				return nil, scanErr
			}
			if !more {
				return rows, nil
			}
			if decErr := r.decryptIfNeeded(ctx, &msg); decErr != nil {
				return nil, decErr
			}
			rows = append(rows, msg)
		}
		return rows, nil
	}
}
// GetMessagesBefore returns one page of messages in roomID, read newest-first
// (ORDER BY created_at DESC) through a walkDesc bucket walk.
//
// The walk starts at the bucket of before, or at the bucket carried by
// pageReq's cursor when that cursor is accepted by startBucketFromCursor, and
// is bounded by the bucket of floor. The bucket that fillPage flags as the
// first one is filtered with created_at < before; every other bucket is read
// without a created_at filter. Errors from cursor handling or from fillPage
// are wrapped with "get messages before" and returned with a zero Page.
func (r *Repository) GetMessagesBefore(ctx context.Context, roomID string, before time.Time, floor time.Time, pageReq PageRequest) (Page[models.Message], error) {
	const (
		firstBucketCQL = messageByRoomQuery + ` WHERE room_id = ? AND bucket = ? AND created_at < ? ORDER BY created_at DESC`
		otherBucketCQL = messageByRoomQuery + ` WHERE room_id = ? AND bucket = ? ORDER BY created_at DESC`
	)
	var empty Page[models.Message]

	lowestBucket := r.bucket.Of(floor)
	firstBucketID, resumeState, err := startBucketFromCursor(pageReq, walkDesc, r.bucket.Of(before), lowestBucket)
	if err != nil {
		return empty, fmt.Errorf("get messages before: %w", err)
	}

	buildQuery := func(bucket int64, firstBucket bool) *gocql.Query {
		if !firstBucket {
			return r.session.Query(otherBucketCQL, roomID, bucket)
		}
		// Only the first bucket needs the upper bound.
		return r.session.Query(firstBucketCQL, roomID, bucket, before)
	}

	walked, err := fillPage[models.Message](
		ctx, r.bucket, walkDesc, firstBucketID, lowestBucket, r.walkCfg,
		pageReq.PageSize, resumeState, buildQuery, r.scanMessagesUpTo(ctx),
	)
	if err != nil {
		return empty, fmt.Errorf("get messages before: %w", err)
	}
	return walked.toPage(), nil
}
// GetMessagesBetweenDesc returns one page of messages in roomID whose
// created_at lies strictly between since and before, read newest-first
// (ORDER BY created_at DESC) through a walkDesc bucket walk.
//
// The walk starts at the bucket of before, or at the bucket carried by
// pageReq's cursor when that cursor is accepted by startBucketFromCursor, and
// is bounded by the bucket of since. The created_at filter depends on where a
// bucket sits in the walk:
//   - first bucket that is also the floor bucket: created_at > since AND created_at < before
//   - first bucket only: created_at < before
//   - floor bucket only: created_at > since, so rows at or before since in
//     the floor bucket are excluded
//   - any other bucket: no created_at filter
//
// Errors from cursor handling or from fillPage are wrapped with
// "get messages between desc" and returned with a zero Page.
func (r *Repository) GetMessagesBetweenDesc(ctx context.Context, roomID string, since, before time.Time, pageReq PageRequest) (Page[models.Message], error) {
	const (
		bothBoundsCQL = messageByRoomQuery + ` WHERE room_id = ? AND bucket = ? AND created_at > ? AND created_at < ? ORDER BY created_at DESC`
		upperBoundCQL = messageByRoomQuery + ` WHERE room_id = ? AND bucket = ? AND created_at < ? ORDER BY created_at DESC`
		lowerBoundCQL = messageByRoomQuery + ` WHERE room_id = ? AND bucket = ? AND created_at > ? ORDER BY created_at DESC`
		noBoundCQL    = messageByRoomQuery + ` WHERE room_id = ? AND bucket = ? ORDER BY created_at DESC`
	)
	var empty Page[models.Message]

	floorBucket := r.bucket.Of(since)
	firstBucketID, resumeState, err := startBucketFromCursor(pageReq, walkDesc, r.bucket.Of(before), floorBucket)
	if err != nil {
		return empty, fmt.Errorf("get messages between desc: %w", err)
	}

	buildQuery := func(bucket int64, firstBucket bool) *gocql.Query {
		isFloor := bucket == floorBucket
		if firstBucket {
			if isFloor {
				// The whole walk fits in one bucket: apply both bounds.
				return r.session.Query(bothBoundsCQL, roomID, bucket, since, before)
			}
			// Top of the walk: only the upper bound matters.
			return r.session.Query(upperBoundCQL, roomID, bucket, before)
		}
		if isFloor {
			// Bottom of the walk: keep rows at or before since out of the page.
			return r.session.Query(lowerBoundCQL, roomID, bucket, since)
		}
		return r.session.Query(noBoundCQL, roomID, bucket)
	}

	walked, err := fillPage[models.Message](
		ctx, r.bucket, walkDesc, firstBucketID, floorBucket, r.walkCfg,
		pageReq.PageSize, resumeState, buildQuery, r.scanMessagesUpTo(ctx),
	)
	if err != nil {
		return empty, fmt.Errorf("get messages between desc: %w", err)
	}
	return walked.toPage(), nil
}
// GetMessagesAfter returns one page of messages in roomID, read oldest-first
// (ORDER BY created_at ASC) through a walkAsc bucket walk.
//
// The walk starts at the bucket of after, or at the bucket carried by
// pageReq's cursor when that cursor is accepted by startBucketFromCursor, and
// is bounded by the bucket of ceiling. The bucket that fillPage flags as the
// first one is filtered with created_at > after; every other bucket is read
// without a created_at filter. Errors from cursor handling or from fillPage
// are wrapped with "get messages after" and returned with a zero Page.
func (r *Repository) GetMessagesAfter(ctx context.Context, roomID string, after time.Time, ceiling time.Time, pageReq PageRequest) (Page[models.Message], error) {
	const (
		firstBucketCQL = messageByRoomQuery + ` WHERE room_id = ? AND bucket = ? AND created_at > ? ORDER BY created_at ASC`
		otherBucketCQL = messageByRoomQuery + ` WHERE room_id = ? AND bucket = ? ORDER BY created_at ASC`
	)
	var empty Page[models.Message]

	highestBucket := r.bucket.Of(ceiling)
	firstBucketID, resumeState, err := startBucketFromCursor(pageReq, walkAsc, r.bucket.Of(after), highestBucket)
	if err != nil {
		return empty, fmt.Errorf("get messages after: %w", err)
	}

	buildQuery := func(bucket int64, firstBucket bool) *gocql.Query {
		if !firstBucket {
			return r.session.Query(otherBucketCQL, roomID, bucket)
		}
		// Only the first bucket needs the lower bound.
		return r.session.Query(firstBucketCQL, roomID, bucket, after)
	}

	walked, err := fillPage[models.Message](
		ctx, r.bucket, walkAsc, firstBucketID, highestBucket, r.walkCfg,
		pageReq.PageSize, resumeState, buildQuery, r.scanMessagesUpTo(ctx),
	)
	if err != nil {
		return empty, fmt.Errorf("get messages after: %w", err)
	}
	return walked.toPage(), nil
}
// GetAllMessagesAsc returns one page of messages in roomID, read oldest-first
// (ORDER BY created_at ASC) through a walkAsc bucket walk.
//
// The walk starts at the bucket of floor, or at the bucket carried by
// pageReq's cursor when that cursor is accepted by startBucketFromCursor, and
// is bounded by the bucket of ceiling. No created_at filter is applied in any
// bucket; floor and ceiling only select which buckets are visited. Errors from
// cursor handling or from fillPage are wrapped with "get all messages asc" and
// returned with a zero Page.
func (r *Repository) GetAllMessagesAsc(ctx context.Context, roomID string, floor time.Time, ceiling time.Time, pageReq PageRequest) (Page[models.Message], error) {
	const bucketCQL = messageByRoomQuery + ` WHERE room_id = ? AND bucket = ? ORDER BY created_at ASC`
	var empty Page[models.Message]

	highestBucket := r.bucket.Of(ceiling)
	firstBucketID, resumeState, err := startBucketFromCursor(pageReq, walkAsc, r.bucket.Of(floor), highestBucket)
	if err != nil {
		return empty, fmt.Errorf("get all messages asc: %w", err)
	}

	// Every bucket uses the same statement, so the first-bucket flag is unused.
	buildQuery := func(bucket int64, _ bool) *gocql.Query {
		return r.session.Query(bucketCQL, roomID, bucket)
	}

	walked, err := fillPage[models.Message](
		ctx, r.bucket, walkAsc, firstBucketID, highestBucket, r.walkCfg,
		pageReq.PageSize, resumeState, buildQuery, r.scanMessagesUpTo(ctx),
	)
	if err != nil {
		return empty, fmt.Errorf("get all messages asc: %w", err)
	}
	return walked.toPage(), nil
}