
Commit f456ebd

diskpersist to filter-replay model, restore unit tests
1 parent 4873ace commit f456ebd
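
Note on the change: the removed code below handled takedowns by rewriting event log files in place (mutateUserEventsInLog); this commit instead records taken-down Uids in a small DB table plus an in-memory cache and filters matching events out at replay time. A minimal, runnable sketch of the filter-at-replay idea (names and types here are illustrative only, not code from this commit):

package main

import "fmt"

// event stands in for a persisted firehose event; Usr is the repo's Uid.
type event struct {
	Usr uint64
	Seq int64
}

// replay emits every event whose repo is not in the taken-down set;
// the log data itself is never rewritten.
func replay(events []event, takenDown map[uint64]bool, emit func(event)) {
	for _, ev := range events {
		if takenDown[ev.Usr] {
			continue // filtered at read time
		}
		emit(ev)
	}
}

func main() {
	evts := []event{{Usr: 1, Seq: 10}, {Usr: 2, Seq: 11}, {Usr: 1, Seq: 12}}
	taken := map[uint64]bool{2: true}
	replay(evts, taken, func(ev event) { fmt.Println("emit seq", ev.Seq) })
}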

File tree: 3 files changed (+603, -124 lines)


cmd/relay/events/diskpersist/diskpersist.go (+126, -124)
@@ -12,7 +12,9 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+	"slices"
 	"sync"
+	"sync/atomic"
 	"time"

 	"github.com/bluesky-social/indigo/api/atproto"
@@ -57,6 +59,11 @@ type DiskPersistence struct {
 	log *slog.Logger

 	lk sync.Mutex
+
+	// takenDownCache is only written with a newly allocated slice; it is always safe to grab a reference to it and use that until done
+	// takenDownUpdateLock is held when generating a new slice
+	takenDownUpdateLock sync.Mutex
+	takenDownCache      atomic.Pointer[[]models.Uid]
 }

 type persistJob struct {
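
Note: the comment on takenDownCache describes a copy-on-write invariant. A minimal, standalone sketch of that pattern (uint64 stands in for models.Uid; simplified relative to the actual TakeDownRepo code later in this diff):

package main

import (
	"fmt"
	"slices"
	"sync"
	"sync/atomic"
)

// cowSet: readers Load() the pointer once and use that snapshot lock-free;
// writers hold the mutex, build a brand-new sorted slice, then Store() it.
// A published slice is never mutated in place.
type cowSet struct {
	mu  sync.Mutex               // serializes writers only
	ptr atomic.Pointer[[]uint64] // always a complete, sorted slice (or nil)
}

func (c *cowSet) contains(v uint64) bool {
	p := c.ptr.Load() // one atomic read; the snapshot stays valid afterwards
	if p == nil {
		return false
	}
	_, found := slices.BinarySearch(*p, v)
	return found
}

func (c *cowSet) add(v uint64) {
	c.mu.Lock()
	defer c.mu.Unlock()
	var cur []uint64
	if p := c.ptr.Load(); p != nil {
		cur = *p
	}
	next := append(slices.Clone(cur), v) // copy, never modify cur in place
	slices.Sort(next)
	c.ptr.Store(&next) // publish the replacement atomically
}

func main() {
	var c cowSet
	c.add(42)
	fmt.Println(c.contains(42), c.contains(7)) // true false
}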
@@ -118,7 +125,14 @@ func NewDiskPersistence(primaryDir, archiveDir string, db *gorm.DB, opts *DiskPe
 		return nil, fmt.Errorf("failed to create did cache: %w", err)
 	}

-	db.AutoMigrate(&LogFileRef{})
+	err = db.AutoMigrate(&LogFileRef{})
+	if err != nil {
+		return nil, fmt.Errorf("gorm setup LogFileRef: %w", err)
+	}
+	err = db.AutoMigrate(&DiskPersistTakedown{})
+	if err != nil {
+		return nil, fmt.Errorf("gorm setup DiskPersistTakedown: %w", err)
+	}

 	bufpool := &sync.Pool{
 		New: func() any {
@@ -171,6 +185,11 @@ type LogFileRef struct {
 	SeqStart int64
 }

+type DiskPersistTakedown struct {
+	gorm.Model
+	Uid models.Uid `gorm:"unique"`
+}
+
 func (dp *DiskPersistence) SetUidSource(uids UidSource) {
 	dp.uids = uids
 }
@@ -544,37 +563,37 @@ func (dp *DiskPersistence) Persist(ctx context.Context, xevt *events.XRPCStreamE
 		evtKind = evtKindCommit
 		did = xevt.RepoCommit.Repo
 		if err := xevt.RepoCommit.MarshalCBOR(cw); err != nil {
-			return fmt.Errorf("failed to marshal: %w", err)
+			return fmt.Errorf("failed to marshal commit: %w", err)
 		}
 	case xevt.RepoSync != nil:
 		evtKind = evtKindSync
 		did = xevt.RepoSync.Did
 		if err := xevt.RepoSync.MarshalCBOR(cw); err != nil {
-			return fmt.Errorf("failed to marshal: %w", err)
+			return fmt.Errorf("failed to marshal sync: %w", err)
 		}
 	case xevt.RepoHandle != nil:
 		evtKind = evtKindHandle
 		did = xevt.RepoHandle.Did
 		if err := xevt.RepoHandle.MarshalCBOR(cw); err != nil {
-			return fmt.Errorf("failed to marshal: %w", err)
+			return fmt.Errorf("failed to marshal handle: %w", err)
 		}
 	case xevt.RepoIdentity != nil:
 		evtKind = evtKindIdentity
 		did = xevt.RepoIdentity.Did
 		if err := xevt.RepoIdentity.MarshalCBOR(cw); err != nil {
-			return fmt.Errorf("failed to marshal: %w", err)
+			return fmt.Errorf("failed to marshal ident: %w", err)
 		}
 	case xevt.RepoAccount != nil:
 		evtKind = evtKindAccount
 		did = xevt.RepoAccount.Did
 		if err := xevt.RepoAccount.MarshalCBOR(cw); err != nil {
-			return fmt.Errorf("failed to marshal: %w", err)
+			return fmt.Errorf("failed to marshal account: %w", err)
 		}
 	case xevt.RepoTombstone != nil:
 		evtKind = evtKindTombstone
 		did = xevt.RepoTombstone.Did
 		if err := xevt.RepoTombstone.MarshalCBOR(cw); err != nil {
-			return fmt.Errorf("failed to marshal: %w", err)
+			return fmt.Errorf("failed to marshal tombstone: %w", err)
 		}
 	default:
 		return nil
@@ -679,6 +698,16 @@ func (dp *DiskPersistence) uidForDid(ctx context.Context, did string) (models.Ui
 	return uid, nil
 }

+type takedownSet []models.Uid
+
+func (ts *takedownSet) isTakendown(uid models.Uid) bool {
+	if ts == nil {
+		return false
+	}
+	_, found := slices.BinarySearch(*ts, uid)
+	return found
+}
+
 func (dp *DiskPersistence) Playback(ctx context.Context, since int64, cb func(*events.XRPCStreamEvent) error) error {
 	var logs []LogFileRef
 	needslogs := true
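
Note: isTakendown is deliberately safe to call on a nil *takedownSet, which is what Playback gets when no takedowns have ever been cached. A small self-contained sketch of that nil-receiver idiom (uint64 in place of models.Uid; slices.BinarySearch assumes the slice is kept sorted ascending):

package main

import (
	"fmt"
	"slices"
)

// set must be kept sorted ascending for BinarySearch to be correct.
// A nil *set is a valid "empty" set: the pointer method checks for nil
// before dereferencing, so callers never need a separate nil guard.
type set []uint64

func (s *set) has(v uint64) bool {
	if s == nil {
		return false
	}
	_, found := slices.BinarySearch(*s, v)
	return found
}

func main() {
	var none *set
	fmt.Println(none.has(3)) // false, no panic

	s := set{1, 3, 5}
	fmt.Println(s.has(3), s.has(4)) // true false
}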
@@ -693,6 +722,9 @@ func (dp *DiskPersistence) Playback(ctx context.Context, since int64, cb func(*e
 		}
 	}

+	var takedownUids *takedownSet
+	takedownUids = (*takedownSet)(dp.takenDownCache.Load())
+
 	// playback data from all the log files we found, then check the db to see if more were written during playback.
 	// repeat a few times but not unboundedly.
 	// don't decrease '10' below 2 because we should always do two passes through this if the above before-chunk query was used.
@@ -703,7 +735,7 @@
 			}
 		}

-		lastSeq, err := dp.PlaybackLogfiles(ctx, since, cb, logs)
+		lastSeq, err := dp.playbackLogfiles(ctx, since, cb, logs, takedownUids)
 		if err != nil {
 			return err
 		}
@@ -720,9 +752,9 @@
 	return nil
 }

-func (dp *DiskPersistence) PlaybackLogfiles(ctx context.Context, since int64, cb func(*events.XRPCStreamEvent) error, logFiles []LogFileRef) (*int64, error) {
+func (dp *DiskPersistence) playbackLogfiles(ctx context.Context, since int64, cb func(*events.XRPCStreamEvent) error, logFiles []LogFileRef, takedownUids *takedownSet) (*int64, error) {
 	for i, lf := range logFiles {
-		lastSeq, err := dp.readEventsFrom(ctx, since, filepath.Join(dp.primaryDir, lf.Path), cb)
+		lastSeq, err := dp.readEventsFrom(ctx, since, filepath.Join(dp.primaryDir, lf.Path), cb, takedownUids)
 		if err != nil {
 			return nil, err
 		}
@@ -746,7 +778,7 @@ func postDoNotEmit(flags uint32) bool {
 	return false
 }

-func (dp *DiskPersistence) readEventsFrom(ctx context.Context, since int64, fn string, cb func(*events.XRPCStreamEvent) error) (*int64, error) {
+func (dp *DiskPersistence) readEventsFrom(ctx context.Context, since int64, fn string, cb func(*events.XRPCStreamEvent) error, takedownUids *takedownSet) (*int64, error) {
 	fi, err := os.OpenFile(fn, os.O_RDONLY, 0)
 	if err != nil {
 		return nil, err
@@ -784,9 +816,9 @@

 		lastSeq = h.Seq

-		if postDoNotEmit(h.Flags) {
+		if takedownUids.isTakendown(h.Usr) || postDoNotEmit(h.Flags) {
 			// event taken down, skip
-			_, err := io.CopyN(io.Discard, bufr, h.Len64()) // would be really nice if the buffered reader had a 'skip' method that does a seek under the hood
+			_, err := bufr.Discard(int(h.Len))
 			if err != nil {
 				return nil, fmt.Errorf("failed while skipping event (seq: %d, fn: %q): %w", h.Seq, fn, err)
 			}
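
Note: bufr appears to be a *bufio.Reader, and (*bufio.Reader).Discard skips exactly n bytes within the buffer, which is what replaces the old io.CopyN into io.Discard here. A tiny standalone illustration (the header/payload layout below is made up for the example, not the real log format):

package main

import (
	"bufio"
	"fmt"
	"io"
	"strings"
)

func main() {
	r := bufio.NewReader(strings.NewReader("HEADERpayload-to-skipREST"))

	hdr := make([]byte, 6)
	if _, err := io.ReadFull(r, hdr); err != nil {
		panic(err)
	}

	// Discard skips the 15-byte payload without copying it anywhere.
	skipped, err := r.Discard(15)
	if err != nil {
		panic(err)
	}

	rest, _ := io.ReadAll(r)
	fmt.Printf("%s | skipped %d bytes | %s\n", hdr, skipped, rest)
	// Output: HEADER | skipped 15 bytes | REST
}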
@@ -855,132 +887,102 @@
 	}
 }

-type UserAction struct {
-	gorm.Model
-
-	Usr      models.Uid
-	RebaseAt int64
-	Takedown bool
+func (dp *DiskPersistence) dbSetRepoTakedown(ctx context.Context, usr models.Uid) error {
+	td := DiskPersistTakedown{
+		Uid: usr,
+	}
+	err := dp.meta.Create(&td).Error
+	if errors.Is(err, gorm.ErrDuplicatedKey) {
+		// already there, okay!
+		return nil
+	}
+	return err
 }
-
-func (dp *DiskPersistence) TakeDownRepo(ctx context.Context, usr models.Uid) error {
-	/*
-		if err := p.meta.Create(&UserAction{
-			Usr: usr,
-			Takedown: true,
-		}).Error; err != nil {
-			return err
+func (dp *DiskPersistence) dbClearRepoTakedown(ctx context.Context, usr models.Uid) error {
+	var takedown DiskPersistTakedown
+	dp.meta.Model(&takedown).Delete(&takedown)
+	result := dp.meta.Model(&takedown).First(&takedown, "uid = ?", usr)
+	if result.Error != nil {
+		if errors.Is(result.Error, gorm.ErrRecordNotFound) {
+			// already gone, no problem
+			return nil
 		}
-	*/
-
-	return dp.forEachShardWithUserEvents(ctx, usr, func(ctx context.Context, fn string) error {
-		if err := dp.deleteEventsForUser(ctx, usr, fn); err != nil {
-			return err
-		}
-
-		return nil
-	})
+	}
+	return result.Error
 }

-func (dp *DiskPersistence) forEachShardWithUserEvents(ctx context.Context, usr models.Uid, cb func(context.Context, string) error) error {
-	var refs []LogFileRef
-	if err := dp.meta.Order("created_at desc").Find(&refs).Error; err != nil {
-		return err
+// return []Uid slice with new Uid in it
+func copyUidSliceWithInsert(uids []models.Uid, newUid models.Uid) (newUids []models.Uid, alreadyThere bool) {
+	insertPos, found := slices.BinarySearch(uids, newUid)
+	if found {
+		// TODO: log error, we were expecting it to not already be in the list
+		return uids, true
 	}
+	out := make([]models.Uid, len(uids)+1)
+	copy(out[:insertPos], uids[:insertPos])
+	out[insertPos] = newUid
+	copy(out[insertPos+1:], uids[insertPos:])
+	return out, false
+}

-	for _, r := range refs {
-		mhas, err := dp.refMaybeHasUserEvents(ctx, usr, r)
-		if err != nil {
-			return err
+func (dp *DiskPersistence) TakeDownRepo(ctx context.Context, usr models.Uid) error {
+	takedownsP := dp.takenDownCache.Load()
+	if takedownsP != nil {
+		_, found := slices.BinarySearch(*takedownsP, usr)
+		if found {
+			// already in cache, okay
+			return nil
 		}
-
-		if mhas {
-			var path string
-			if r.Archived {
-				path = filepath.Join(dp.archiveDir, r.Path)
-			} else {
-				path = filepath.Join(dp.primaryDir, r.Path)
-			}
-
-			if err := cb(ctx, path); err != nil {
-				return err
-			}
+	}
+	err := dp.dbSetRepoTakedown(ctx, usr)
+	if err != nil {
+		return err
+	}
+	dp.takenDownUpdateLock.Lock()
+	defer dp.takenDownUpdateLock.Unlock()
+	takedownsP = dp.takenDownCache.Load()
+	if takedownsP == nil {
+		newTakedowns := make([]models.Uid, 1)
+		newTakedowns[0] = usr
+		dp.takenDownCache.Store(&newTakedowns)
+	} else {
+		newTakedowns, alreadyThere := copyUidSliceWithInsert(*takedownsP, usr)
+		if alreadyThere {
+			return nil
 		}
+		dp.takenDownCache.Store(&newTakedowns)
 	}
-
 	return nil
 }
-
-func (dp *DiskPersistence) refMaybeHasUserEvents(ctx context.Context, usr models.Uid, ref LogFileRef) (bool, error) {
-	// TODO: lazily computed bloom filters for users in each logfile
-	return true, nil
-}
-
-type zeroReader struct{}
-
-func (zr *zeroReader) Read(p []byte) (n int, err error) {
-	for i := range p {
-		p[i] = 0
+func (dp *DiskPersistence) ReverseTakeDownRepo(ctx context.Context, usr models.Uid) error {
+	takedownsP := dp.takenDownCache.Load()
+	if takedownsP == nil {
+		// nothing is there, ignore
+		return nil
 	}
-	return len(p), nil
-}
-
-func (dp *DiskPersistence) deleteEventsForUser(ctx context.Context, usr models.Uid, fn string) error {
-	return dp.mutateUserEventsInLog(ctx, usr, fn, EvtFlagTakedown, true)
-}
-
-func (dp *DiskPersistence) mutateUserEventsInLog(ctx context.Context, usr models.Uid, fn string, flag uint32, zeroEvts bool) error {
-	fi, err := os.OpenFile(fn, os.O_RDWR, 0)
+	foundPos, found := slices.BinarySearch(*takedownsP, usr)
+	if !found {
+		// already gone, okay
+		return nil
+	}
+	err := dp.dbClearRepoTakedown(ctx, usr)
 	if err != nil {
-		return fmt.Errorf("failed to open log file: %w", err)
+		return err
 	}
-	defer fi.Close()
-	defer fi.Sync()
-
-	scratch := make([]byte, headerSize)
-	var offset int64
-	for {
-		h, err := readHeader(fi, scratch)
-		if err != nil {
-			if errors.Is(err, io.EOF) {
-				return nil
-			}
-
-			return err
-		}
-
-		if h.Usr == usr && h.Flags&flag == 0 {
-			nflag := h.Flags | flag
-
-			binary.LittleEndian.PutUint32(scratch, nflag)
-
-			if _, err := fi.WriteAt(scratch[:4], offset); err != nil {
-				return fmt.Errorf("failed to write updated flag value: %w", err)
-			}
-
-			if zeroEvts {
-				// sync that write before blanking the event data
-				if err := fi.Sync(); err != nil {
-					return err
-				}
-
-				if _, err := fi.Seek(offset+headerSize, io.SeekStart); err != nil {
-					return fmt.Errorf("failed to seek: %w", err)
-				}
-
-				_, err := io.CopyN(fi, &zeroReader{}, h.Len64())
-				if err != nil {
-					return err
-				}
-			}
-		}
-
-		offset += headerSize + h.Len64()
-		_, err = fi.Seek(offset, io.SeekStart)
-		if err != nil {
-			return fmt.Errorf("failed to seek: %w", err)
-		}
+	dp.takenDownUpdateLock.Lock()
+	defer dp.takenDownUpdateLock.Unlock()
+	takedownsP = dp.takenDownCache.Load()
+	foundPos, found = slices.BinarySearch(*takedownsP, usr)
+	if !found {
+		// already gone, okay
+		return nil
 	}
+	oldTakedowns := *takedownsP
+	newTakedowns := make([]models.Uid, len(oldTakedowns)-1)
+	copy(newTakedowns[:foundPos], oldTakedowns[:foundPos])
+	copy(newTakedowns[foundPos:], oldTakedowns[foundPos+1:])
+	dp.takenDownCache.Store(&newTakedowns)
+	return nil
 }

 func (dp *DiskPersistence) Flush(ctx context.Context) error {
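
Note: TakeDownRepo and ReverseTakeDownRepo above never edit the published slice; they build a fresh sorted slice and swap the pointer. A standalone sketch of just those insert/remove-by-copy steps (uint64 in place of models.Uid; the function names are for this example only):

package main

import (
	"fmt"
	"slices"
)

// insertSorted returns a new sorted slice containing v; the input is left untouched.
func insertSorted(uids []uint64, v uint64) []uint64 {
	pos, found := slices.BinarySearch(uids, v)
	if found {
		return uids
	}
	out := make([]uint64, len(uids)+1)
	copy(out[:pos], uids[:pos])
	out[pos] = v
	copy(out[pos+1:], uids[pos:])
	return out
}

// removeSorted returns a new sorted slice without v; the input is left untouched.
func removeSorted(uids []uint64, v uint64) []uint64 {
	pos, found := slices.BinarySearch(uids, v)
	if !found {
		return uids
	}
	out := make([]uint64, len(uids)-1)
	copy(out[:pos], uids[:pos])
	copy(out[pos:], uids[pos+1:])
	return out
}

func main() {
	s := []uint64{2, 5, 9}
	s = insertSorted(s, 7)
	fmt.Println(s) // [2 5 7 9]
	s = removeSorted(s, 5)
	fmt.Println(s) // [2 7 9]
}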
