diff --git a/docs/server/hooks.md b/docs/server/hooks.md index cffb0045..5ccf4509 100644 --- a/docs/server/hooks.md +++ b/docs/server/hooks.md @@ -49,7 +49,6 @@ By way of illustration, if a user isn’t allowed to connect: Just throw an erro | `beforeBroadcastStateless` | Before broadcast a stateless message | [Read more](/server/hooks#before-broadcast-stateless) | | `afterUnloadDocument` | When a document is closed | [Read more](/server/hooks#after-unload-document) | - ## Usage ```js @@ -759,7 +758,6 @@ const server = new Server({ server.listen(); ``` - ### onStateless The `onStateless` hooks are called after the server has received a stateless message. It should return a Promise. @@ -796,6 +794,37 @@ const server = new Server({ server.listen() ``` +### beforeSync + +The `beforeSync` hooks are called before a sync message is handled. This is useful if you want to inspect the sync message that will be applied to the document. + +**Hook payload** + +```js +const data = { + documentName: string, + document: Document, + // The y-protocols/sync message type + type: number, + // The payload of the y-protocols/sync message + payload: Uint8Array, +} +``` + +**Example** + +```js +import { Server } from '@hocuspocus/server' + +const server = new Server({ + async beforeSync({ payload, document, documentName, type }) { + console.log(`Server will handle a sync message: "${payload}"!`) + }, +}) + +server.listen() +``` + ### beforeBroadcastStateless The `beforeBroadcastStateless` hooks are called before the server broadcast a stateless message. diff --git a/packages/server/src/ClientConnection.ts b/packages/server/src/ClientConnection.ts index 7f4fd233..f8d16df1 100644 --- a/packages/server/src/ClientConnection.ts +++ b/packages/server/src/ClientConnection.ts @@ -16,6 +16,7 @@ import { OutgoingMessage } from './OutgoingMessage.ts' import type { ConnectionConfiguration, beforeHandleMessagePayload, + beforeSyncPayload, onDisconnectPayload, } from './types.ts' import { @@ -199,6 +200,20 @@ export class ClientConnection { return this.hooks('beforeHandleMessage', beforeHandleMessagePayload) }) + instance.beforeSync((connection, payload) => { + const beforeSyncPayload: beforeSyncPayload = { + clientsCount: document.getConnectionsCount(), + context: hookPayload.context, + document, + documentName: document.name, + connection, + type: payload.type, + payload: payload.payload, + } + + return this.hooks('beforeSync', beforeSyncPayload) + }) + return instance } diff --git a/packages/server/src/Connection.ts b/packages/server/src/Connection.ts index 63ca4128..25adfbc0 100644 --- a/packages/server/src/Connection.ts +++ b/packages/server/src/Connection.ts @@ -8,7 +8,7 @@ import type Document from './Document.ts' import { IncomingMessage } from './IncomingMessage.ts' import { MessageReceiver } from './MessageReceiver.ts' import { OutgoingMessage } from './OutgoingMessage.ts' -import type { onStatelessPayload } from './types.ts' +import type { beforeSyncPayload, onStatelessPayload } from './types.ts' export class Connection { @@ -23,6 +23,7 @@ export class Connection { callbacks = { onClose: [(document: Document, event?: CloseEvent) => {}], beforeHandleMessage: (connection: Connection, update: Uint8Array) => Promise.resolve(), + beforeSync: (connection: Connection, payload: Pick) => Promise.resolve(), statelessCallback: (payload: onStatelessPayload) => Promise.resolve(), } @@ -81,6 +82,15 @@ export class Connection { return this } + /** + * Set a callback that will be triggered before a sync message is handled + */ + beforeSync(callback: (connection: Connection, payload: Pick) => Promise): Connection { + this.callbacks.beforeSync = callback + + return this + } + /** * Send the given message */ diff --git a/packages/server/src/Hocuspocus.ts b/packages/server/src/Hocuspocus.ts index 911c6262..e01145d2 100644 --- a/packages/server/src/Hocuspocus.ts +++ b/packages/server/src/Hocuspocus.ts @@ -49,6 +49,7 @@ export class Hocuspocus { onConnect: () => new Promise(r => r(null)), connected: () => new Promise(r => r(null)), beforeHandleMessage: () => new Promise(r => r(null)), + beforeSync: () => new Promise(r => r(null)), beforeBroadcastStateless: () => new Promise(r => r(null)), onStateless: () => new Promise(r => r(null)), onChange: () => new Promise(r => r(null)), @@ -111,6 +112,7 @@ export class Hocuspocus { afterLoadDocument: this.configuration.afterLoadDocument, beforeHandleMessage: this.configuration.beforeHandleMessage, beforeBroadcastStateless: this.configuration.beforeBroadcastStateless, + beforeSync: this.configuration.beforeSync, onStateless: this.configuration.onStateless, onChange: this.configuration.onChange, onStoreDocument: this.configuration.onStoreDocument, diff --git a/packages/server/src/IncomingMessage.ts b/packages/server/src/IncomingMessage.ts index a41919b3..7ac55795 100644 --- a/packages/server/src/IncomingMessage.ts +++ b/packages/server/src/IncomingMessage.ts @@ -49,6 +49,13 @@ export class IncomingMessage { return readVarUint8Array(this.decoder) } + peekVarUint8Array() { + const { pos } = this.decoder + const result = readVarUint8Array(this.decoder) + this.decoder.pos = pos + return result + } + readVarUint() { return readVarUint(this.decoder) } diff --git a/packages/server/src/MessageReceiver.ts b/packages/server/src/MessageReceiver.ts index 698d26c4..61f117d0 100644 --- a/packages/server/src/MessageReceiver.ts +++ b/packages/server/src/MessageReceiver.ts @@ -27,7 +27,7 @@ export class MessageReceiver { this.defaultTransactionOrigin = defaultTransactionOrigin } - public apply(document: Document, connection?: Connection, reply?: (message: Uint8Array) => void) { + public async apply(document: Document, connection?: Connection, reply?: (message: Uint8Array) => void) { const { message } = this const type = message.readVarUint() const emptyMessageLength = message.length @@ -36,7 +36,7 @@ export class MessageReceiver { case MessageType.Sync: case MessageType.SyncReply: { message.writeVarUint(MessageType.Sync) - this.readSyncMessage(message, document, connection, reply, type !== MessageType.SyncReply) + await this.readSyncMessage(message, document, connection, reply, type !== MessageType.SyncReply) if (message.length > emptyMessageLength + 1) { if (reply) { @@ -55,7 +55,7 @@ export class MessageReceiver { break } case MessageType.Awareness: { - applyAwarenessUpdate(document.awareness, message.readVarUint8Array(), connection?.webSocket) + await applyAwarenessUpdate(document.awareness, message.readVarUint8Array(), connection?.webSocket) break } @@ -101,9 +101,16 @@ export class MessageReceiver { } } - readSyncMessage(message: IncomingMessage, document: Document, connection?: Connection, reply?: (message: Uint8Array) => void, requestFirstSync = true) { + async readSyncMessage(message: IncomingMessage, document: Document, connection?: Connection, reply?: (message: Uint8Array) => void, requestFirstSync = true) { const type = message.readVarUint() + if (connection) { + await connection.callbacks.beforeSync(connection, { + type, + payload: message.peekVarUint8Array(), + }) + } + switch (type) { case messageYjsSyncStep1: { readSyncStep1(message.decoder, message.encoder, document) diff --git a/packages/server/src/types.ts b/packages/server/src/types.ts index 0dc68d8d..6b546096 100644 --- a/packages/server/src/types.ts +++ b/packages/server/src/types.ts @@ -44,6 +44,7 @@ export interface Extension { onLoadDocument?(data: onLoadDocumentPayload): Promise; afterLoadDocument?(data: afterLoadDocumentPayload): Promise; beforeHandleMessage?(data: beforeHandleMessagePayload): Promise; + beforeSync?(data: beforeSyncPayload): Promise; beforeBroadcastStateless?(data: beforeBroadcastStatelessPayload): Promise; onStateless?(payload: onStatelessPayload): Promise; onChange?(data: onChangePayload): Promise; @@ -69,6 +70,7 @@ export type HookName = 'afterLoadDocument' | 'beforeHandleMessage' | 'beforeBroadcastStateless' | + 'beforeSync' | 'onStateless' | 'onChange' | 'onStoreDocument' | @@ -92,6 +94,7 @@ export type HookPayloadByName = { afterLoadDocument: afterLoadDocumentPayload, beforeHandleMessage: beforeHandleMessagePayload, beforeBroadcastStateless: beforeBroadcastStatelessPayload, + beforeSync: beforeSyncPayload, onStateless: onStatelessPayload, onChange: onChangePayload, onStoreDocument: onStoreDocumentPayload, @@ -250,6 +253,28 @@ export interface beforeHandleMessagePayload { connection: Connection } +export interface beforeSyncPayload { + clientsCount: number, + context: any, + document: Document, + documentName: string, + connection: Connection, + /** + * The y-protocols/sync message type + * @example + * 0: SyncStep1 + * 1: SyncStep2 + * 2: YjsUpdate + * + * @see https://github.com/yjs/y-protocols/blob/master/sync.js#L13-L40 + */ + type: number, + /** + * The payload of the y-sync message. + */ + payload: Uint8Array, +} + export interface beforeBroadcastStatelessPayload { document: Document, documentName: string, diff --git a/tests/server/beforeSync.ts b/tests/server/beforeSync.ts new file mode 100644 index 00000000..3ca4288d --- /dev/null +++ b/tests/server/beforeSync.ts @@ -0,0 +1,156 @@ +import test from 'ava' +import { newHocuspocus, newHocuspocusProvider } from '../utils/index.ts' +import { retryableAssertion } from '../utils/retryableAssertion.ts' + +test('beforeSync gets called in proper order', async t => { + await new Promise(async resolve => { + const mockContext = { + user: 123, + } + + let callNumber = 0 + + const server = await newHocuspocus({ + async onConnect() { + return mockContext + }, + async beforeSync({ document, context, payload }) { + t.deepEqual(context, mockContext) + + callNumber += 1 + + if (callNumber === 2) { + resolve('done') + } + }, + async onChange({ context, document }) { + t.deepEqual(context, mockContext) + + const value = document.getArray('foo').get(0) + + t.is(value, 'bar') + }, + }) + + const provider = newHocuspocusProvider(server, { + onSynced() { + provider.document.getArray('foo').insert(0, ['bar']) + }, + }) + }) +}) + +test('beforeSync callback is called for every sync', async t => { + let onConnectCount = 0 + let updateCount = 0 + let syncstep1Count = 0 + let syncstep2Count = 0 + + await new Promise(async resolve => { + const server = await newHocuspocus({ + async onConnect() { + onConnectCount += 1 + }, + async beforeSync({ type }) { + if (type === 0){ + syncstep1Count += 1 + } else if (type === 1) { + syncstep2Count += 1 + } else if (type === 2) { + updateCount += 1 + } + }, + }) + + await Promise.all([ + new Promise(done => { + newHocuspocusProvider(server, { + onClose() { + t.fail() + }, + onSynced() { + done('done') + }, + }) + }), + new Promise(done => { + newHocuspocusProvider(server, { + onClose() { + t.fail() + }, + onSynced() { + done('done') + }, + }) + }), + ]) + + resolve('done') + }) + + await retryableAssertion(t, tt => { + tt.is(onConnectCount, 2) + tt.is(syncstep1Count, 2) + tt.is(syncstep2Count, 2) + tt.is(updateCount, 0) + }) +}) + + +test('beforeSync callback is called on every update', async t => { + let onConnectCount = 0 + let updateCount = 0 + let syncstep1Count = 0 + let syncstep2Count = 0 + + + await new Promise(async resolve => { + const server = await newHocuspocus({ + async onConnect() { + onConnectCount += 1 + }, + async beforeSync({ type }) { + if (type === 0){ + syncstep1Count += 1 + } else if (type === 1) { + syncstep2Count += 1 + } else if (type === 2) { + updateCount += 1 + } + }, + }) + + await Promise.all([ + new Promise(done => { + newHocuspocusProvider(server, { + onClose() { + t.fail() + }, + onSynced() { + done('done') + }, + }) + }), + new Promise(done => { + const provider = newHocuspocusProvider(server, { + onClose() { + t.fail() + }, + onSynced() { + provider.document.getArray('foo').insert(0, ['bar']) + done('done') + }, + }) + }), + ]) + + resolve('done') + }) + + await retryableAssertion(t, tt => { + tt.is(onConnectCount, 2) + tt.is(syncstep1Count, 2) + tt.is(syncstep2Count, 2) + tt.is(updateCount, 1) + }) +})