Skip to content

Commit 3fb02f3

Browse files
committed
fix potential cursor issue
1 parent d8b155d commit 3fb02f3

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

server/data_stream.py

Lines changed: 31 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,14 @@
11
import logging
22
from collections import defaultdict
33

4-
from atproto import AtUri, CAR, firehose_models, FirehoseSubscribeReposClient, models, parse_subscribe_repos_message
4+
from atproto import (
5+
AtUri,
6+
CAR,
7+
firehose_models,
8+
FirehoseSubscribeReposClient,
9+
models,
10+
parse_subscribe_repos_message,
11+
)
512
from atproto.exceptions import FirehoseError
613

714
from server.database import SubscriptionState
@@ -15,21 +22,21 @@
1522

1623

1724
def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> defaultdict:
18-
operation_by_type = defaultdict(lambda: {'created': [], 'deleted': []})
25+
operation_by_type = defaultdict(lambda: {"created": [], "deleted": []})
1926

2027
car = CAR.from_bytes(commit.blocks)
2128
for op in commit.ops:
22-
if op.action == 'update':
29+
if op.action == "update":
2330
# we are not interested in updates
2431
continue
2532

26-
uri = AtUri.from_str(f'at://{commit.repo}/{op.path}')
33+
uri = AtUri.from_str(f"at://{commit.repo}/{op.path}")
2734

28-
if op.action == 'create':
35+
if op.action == "create":
2936
if not op.cid:
3037
continue
3138

32-
create_info = {'uri': str(uri), 'cid': str(op.cid), 'author': commit.repo}
39+
create_info = {"uri": str(uri), "cid": str(op.cid), "author": commit.repo}
3340

3441
record_raw_data = car.blocks.get(op.cid)
3542
if not record_raw_data:
@@ -40,12 +47,16 @@ def _get_ops_by_type(commit: models.ComAtprotoSyncSubscribeRepos.Commit) -> defa
4047
continue
4148

4249
for record_type, record_nsid in _INTERESTED_RECORDS.items():
43-
if uri.collection == record_nsid and models.is_record_type(record, record_type):
44-
operation_by_type[record_nsid]['created'].append({'record': record, **create_info})
50+
if uri.collection == record_nsid and models.is_record_type(
51+
record, record_type
52+
):
53+
operation_by_type[record_nsid]["created"].append(
54+
{"record": record, **create_info}
55+
)
4556
break
4657

47-
if op.action == 'delete':
48-
operation_by_type[uri.collection]['deleted'].append({'uri': str(uri)})
58+
if op.action == "delete":
59+
operation_by_type[uri.collection]["deleted"].append({"uri": str(uri)})
4960

5061
return operation_by_type
5162

@@ -57,7 +68,7 @@ def run(name, operations_callback, stream_stop_event=None):
5768
except FirehoseError as e:
5869
if logger.level == logging.DEBUG:
5970
raise e
60-
logger.error(f'Firehose error: {e}. Reconnecting to the firehose.')
71+
logger.error(f"Firehose error: {e}. Reconnecting to the firehose.")
6172

6273

6374
def _run(name, operations_callback, stream_stop_event=None):
@@ -69,9 +80,6 @@ def _run(name, operations_callback, stream_stop_event=None):
6980

7081
client = FirehoseSubscribeReposClient(params)
7182

72-
if not state:
73-
SubscriptionState.create(service=name, cursor=0)
74-
7583
def on_message_handler(message: firehose_models.MessageFrame) -> None:
7684
# stop on next message if requested
7785
if stream_stop_event and stream_stop_event.is_set():
@@ -84,9 +92,15 @@ def on_message_handler(message: firehose_models.MessageFrame) -> None:
8492

8593
# update stored state every ~1k events
8694
if commit.seq % 1000 == 0: # lower value could lead to performance issues
87-
logger.debug(f'Updated cursor for {name} to {commit.seq}')
88-
client.update_params(models.ComAtprotoSyncSubscribeRepos.Params(cursor=commit.seq))
89-
SubscriptionState.update(cursor=commit.seq).where(SubscriptionState.service == name).execute()
95+
logger.debug(f"Updated cursor for {name} to {commit.seq}")
96+
client.update_params(
97+
models.ComAtprotoSyncSubscribeRepos.Params(cursor=commit.seq)
98+
)
99+
SubscriptionState.insert(service=name, cursor=commit.seq).on_conflict(
100+
conflict_target=(SubscriptionState.service,),
101+
action="UPDATE",
102+
update={SubscriptionState.cursor: commit.seq},
103+
).execute()
90104

91105
if not commit.blocks:
92106
return

0 commit comments

Comments
 (0)