Skip to content
This repository was archived by the owner on Jul 25, 2021. It is now read-only.

Commit 4d1b0b2

Browse files
authored
feat: gateway events for messages (#44)
* feat: add new methods to GatewayController * feat: create message create/delete gateway packets * feat: getChannel middleware * refactor: use getChannel middleware * refactor: send gateway events on message create/delete * test: getChannel middleware * test: GatewayController broadcast/sendMessage * refactor: use getMessages middleware for getMessages route * refactor: add generic type to sendMessage
1 parent 33c1c58 commit 4d1b0b2

14 files changed

Lines changed: 302 additions & 89 deletions

File tree

src/controllers/GatewayController.ts

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,15 @@ export default class GatewayController {
5353
}
5454
}
5555

56+
public async broadcast<T extends GatewayPacket>(message: T) {
57+
await this.gatewayService.send([...this.authenticatedClients.keys()], message);
58+
}
59+
60+
public async sendMessage<T extends GatewayPacket>(to: string[], message: T) {
61+
const recipients = [...this.authenticatedClients.entries()].filter(([,id]) => to.includes(id)).map(entry => entry[0]);
62+
await this.gatewayService.send(recipients, message);
63+
}
64+
5665
public async onAuthenticate(ws: WebSocket, packet: IdentifyGatewayPacket) {
5766
if (this.authenticatedClients.has(ws)) throw new GatewayError('Already authenticated!');
5867
const { token } = packet.data;

src/controllers/MessageController.ts

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,32 @@ import { AuthenticatedResponse } from '../routes/middleware/getUser';
44
import MessageService from '../services/MessageService';
55
import { AccountType } from '../entities/User';
66
import { HttpCode } from '../util/errors';
7+
import { ChannelResponse } from '../routes/middleware/getChannel';
8+
import GatewayController from './GatewayController';
9+
import { MessageCreateGatewayPacket, GatewayPacketType, MessageDeleteGatewayPacket } from '../util/gateway';
10+
import { EventChannel } from '../entities/Channel';
711

812
@injectable()
913
export class MessageController {
1014
private readonly messageService: MessageService;
15+
private readonly gatewayController: GatewayController;
1116

12-
public constructor(@inject(MessageService) messageService: MessageService) {
17+
public constructor(@inject(MessageService) messageService: MessageService, @inject(GatewayController) gatewayController: GatewayController) {
1318
this.messageService = messageService;
19+
this.gatewayController = gatewayController;
1420
}
1521

16-
public async createMessage(req: Request & { params: { channelID: string } }, res: AuthenticatedResponse, next: NextFunction): Promise<void> {
22+
public async createMessage(req: Request, res: ChannelResponse, next: NextFunction): Promise<void> {
1723
try {
18-
const message = await this.messageService.createMessage({ ...req.body, channelID: req.params.channelID, authorID: res.locals.user.id });
24+
const message = await this.messageService.createMessage({ ...req.body, channel: res.locals.channel, author: res.locals.user });
25+
if (res.locals.channel instanceof EventChannel) {
26+
await this.gatewayController.broadcast<MessageCreateGatewayPacket>({
27+
type: GatewayPacketType.MessageCreate,
28+
data: {
29+
message
30+
}
31+
});
32+
}
1933
res.json({ message });
2034
} catch (error) {
2135
next(error);
@@ -31,10 +45,10 @@ export class MessageController {
3145
}
3246
}
3347

34-
public async getMessages(req: Request & { params: { channelID: string; page: number } }, res: AuthenticatedResponse, next: NextFunction): Promise<void> {
48+
public async getMessages(req: Request & { params: { channelID: string; page: number } }, res: ChannelResponse, next: NextFunction): Promise<void> {
3549
try {
3650
const messages = await this.messageService.getMessages({
37-
channelID: req.params.channelID,
51+
channel: res.locals.channel,
3852
page: Number(req.query.page),
3953
count: 50
4054
});
@@ -44,13 +58,22 @@ export class MessageController {
4458
}
4559
}
4660

47-
public async deleteMessage(req: Request & { params: { channelID: string; messageID: string } }, res: AuthenticatedResponse, next: NextFunction): Promise<void> {
61+
public async deleteMessage(req: Request & { params: { messageID: string } }, res: ChannelResponse, next: NextFunction): Promise<void> {
4862
try {
4963
await this.messageService.deleteMessage({
5064
id: req.params.messageID,
51-
channelID: req.params.channelID,
65+
channelID: res.locals.channel.id,
5266
authorID: res.locals.user.accountType === AccountType.Admin ? undefined : res.locals.user.id
5367
});
68+
if (res.locals.channel instanceof EventChannel) {
69+
await this.gatewayController.broadcast<MessageDeleteGatewayPacket>({
70+
type: GatewayPacketType.MessageDelete,
71+
data: {
72+
messageID: req.params.messageID,
73+
channelID: res.locals.channel.id
74+
}
75+
});
76+
}
5477
res.status(HttpCode.NoContent).end();
5578
} catch (error) {
5679
next(error);

src/routes/MessageRoutes.ts

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import { Router } from 'express';
22
import { inject, injectable } from 'tsyringe';
33
import { getUser, isVerified } from './middleware';
44
import { MessageController } from '../controllers/MessageController';
5+
import getChannel from './middleware/getChannel';
56

67
@injectable()
78
export class MessageRoutes {
@@ -12,9 +13,9 @@ export class MessageRoutes {
1213
}
1314

1415
public routes(router: Router): void {
15-
router.post('/channels/:channelID/messages', getUser, isVerified, this.messageController.createMessage.bind(this.messageController));
16-
router.get('/channels/:channelID/messages', getUser, isVerified, this.messageController.getMessages.bind(this.messageController));
16+
router.post('/channels/:channelID/messages', getUser, isVerified, getChannel, this.messageController.createMessage.bind(this.messageController));
17+
router.get('/channels/:channelID/messages', getUser, isVerified, getChannel, this.messageController.getMessages.bind(this.messageController));
1718
router.get('/channels/:channelID/messages/:messageID', getUser, isVerified, this.messageController.getMessage.bind(this.messageController));
18-
router.delete('/channels/:channelID/messages/:messageID', getUser, isVerified, this.messageController.deleteMessage.bind(this.messageController));
19+
router.delete('/channels/:channelID/messages/:messageID', getUser, isVerified, getChannel, this.messageController.deleteMessage.bind(this.messageController));
1920
}
2021
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import { NextFunction, Request } from 'express';
2+
import { Channel, EventChannel } from '../../entities/Channel';
3+
import { APIError, HttpCode } from '../../util/errors';
4+
import { AuthenticatedResponse } from './getUser';
5+
import { container } from 'tsyringe';
6+
import ChannelService from '../../services/ChannelService';
7+
8+
enum GetChannelError {
9+
NotAllowed = 'You are not allowed to view this channel',
10+
NotFound = 'Channel not found'
11+
}
12+
13+
export type ChannelResponse = AuthenticatedResponse & { locals: { channel: Channel } };
14+
15+
export default async function getChannel(req: Request, res: AuthenticatedResponse, next: NextFunction) {
16+
const channelService = container.resolve(ChannelService);
17+
if (!req.params.channelID) return next(new APIError(HttpCode.NotFound, GetChannelError.NotFound));
18+
const channel = await channelService.findOne({ id: req.params.channelID });
19+
if (!channel) return next(new APIError(HttpCode.NotFound, GetChannelError.NotFound));
20+
21+
if (channel instanceof EventChannel) {
22+
(res as ChannelResponse).locals.channel = channel;
23+
}
24+
next();
25+
}

src/routes/middleware/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
export { default as getUser } from './getUser';
22
export { default as isVerified } from './isVerified';
33
export { default as isAdmin } from './isAdmin';
4+
export { default as getChannel } from './getChannel';

src/services/ChannelService.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import { singleton } from 'tsyringe';
2+
import { getRepository, FindOneOptions, FindConditions } from 'typeorm';
3+
import { Channel } from '../entities/Channel';
4+
5+
@singleton()
6+
export default class ChannelService {
7+
public async findOne(findConditions: FindConditions<Channel>, options?: FindOneOptions) {
8+
return getRepository(Channel).findOne(findConditions, options);
9+
}
10+
}

src/services/MessageService.ts

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ import { getRepository } from 'typeorm';
44
import { APIError, formatValidationErrors, HttpCode } from '../util/errors';
55
import { validateOrReject } from 'class-validator';
66
import { Channel } from '../entities/Channel';
7+
import { User } from '../entities/User';
78

8-
type MessageCreationData = Omit<APIMessage, 'id'>;
9+
type MessageCreationData = Omit<APIMessage, 'id' | 'channelID' | 'authorID'> & { channel: Channel; author: User };
910

1011
enum GetMessageError {
1112
NotFound = 'Message not found',
@@ -18,18 +19,14 @@ enum DeleteMessageError {
1819
NotAuthor = 'You are not the author of this message'
1920
}
2021

21-
enum GetMessagesError {
22-
ChannelNotFound = 'Channel does not exist'
23-
}
24-
2522
@singleton()
2623
export default class MessageService {
2724
public async createMessage(data: MessageCreationData): Promise<APIMessage> {
2825
const message = new Message();
2926
message.content = data.content;
3027
message.time = new Date(data.time);
31-
message.author = { id: data.authorID } as any;
32-
message.channel = { id: data.channelID } as any;
28+
message.author = data.author;
29+
message.channel = data.channel;
3330
await validateOrReject(message).catch(e => Promise.reject(formatValidationErrors(e)));
3431
return (await getRepository(Message).save(message)).toJSON();
3532
}
@@ -41,13 +38,11 @@ export default class MessageService {
4138
return message.toJSON();
4239
}
4340

44-
public async getMessages(data: { channelID: string; page?: number; count: number }): Promise<APIMessage[]> {
41+
public async getMessages(data: { channel: Channel; page?: number; count: number }): Promise<APIMessage[]> {
4542
if (!data.page || isNaN(data.page)) data.page = 0;
46-
if (!data.channelID) throw new APIError(HttpCode.NotFound, GetMessagesError.ChannelNotFound);
47-
const channel = await getRepository(Channel).findOneOrFail(data.channelID).catch(() => Promise.reject(new APIError(HttpCode.NotFound, GetMessagesError.ChannelNotFound)));
4843
const messages = await getRepository(Message)
4944
.find({
50-
where: { channel },
45+
where: { channel: data.channel },
5146
order: { time: 'DESC' },
5247
skip: data.count * data.page,
5348
take: data.count

src/util/gateway/index.ts

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
import { APIMessage } from '../../entities/Message';
2+
13
export class GatewayError extends Error {}
24

35
export enum GatewayPacketType {
46
Identify = 'IDENTIFY',
5-
Hello = 'HELLO'
7+
Hello = 'HELLO',
8+
MessageCreate = 'MESSAGE_CREATE',
9+
MessageDelete = 'MESSAGE_DELETE'
610
}
711

812
export interface GatewayPacket {
@@ -19,3 +23,18 @@ export interface IdentifyGatewayPacket extends GatewayPacket {
1923
export interface HelloGatewayPacket extends GatewayPacket {
2024
type: GatewayPacketType.Hello;
2125
}
26+
27+
export interface MessageCreateGatewayPacket extends GatewayPacket {
28+
type: GatewayPacketType.MessageCreate;
29+
data: {
30+
message: APIMessage;
31+
};
32+
}
33+
34+
export interface MessageDeleteGatewayPacket extends GatewayPacket {
35+
type: GatewayPacketType.MessageDelete;
36+
data: {
37+
messageID: string;
38+
channelID: string;
39+
};
40+
}

tests/fixtures/messages.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import users from './users';
66
import events from './events';
77
import { v4 as uuidv4 } from 'uuid';
88

9-
export function createMessage(data: { author?: User; channel?: Channel; content?: string; time?: Date }): [Pick<APIMessage, 'content' | 'authorID' | 'channelID' | 'time'>, Message] {
9+
export function createMessage(data: { author?: User; channel?: Channel; content?: string; time?: Date }): [Pick<APIMessage, 'content' | 'time'> & { author: User; channel: Channel }, Message] {
1010
const concreteData = {
1111
author: data.author ?? users.find(user => user.accountStatus === AccountStatus.Verified)!,
1212
channel: data.channel ?? events[0].channel,
@@ -25,8 +25,8 @@ export function createMessage(data: { author?: User; channel?: Channel; content?
2525
*/
2626
return [
2727
{
28-
authorID: concreteData.author.id,
29-
channelID: concreteData.channel.id,
28+
author: concreteData.author,
29+
channel: concreteData.channel,
3030
content: concreteData.content,
3131
time: message.time.toISOString()
3232
},

tests/integration/controllers/gatewayController.test.ts

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,20 @@ import { UserService } from '../../../src/services/UserService';
55
import { container } from 'tsyringe';
66
import { mock, instance, reset, spy, when, objectContaining, verify, anything } from 'ts-mockito';
77
import * as auth from '../../../src/util/auth';
8+
import GatewayController from '../../../src/controllers/GatewayController';
89

910
let spiedVerifyJWT: typeof auth;
1011
let wss: MockWebSocketServer;
1112
let mockedUserService: UserService;
13+
let gatewayController: GatewayController;
1214

1315
beforeAll(() => {
1416
spiedVerifyJWT = spy(auth);
1517
mockedUserService = mock(UserService);
1618
container.clearInstances();
1719
container.register<UserService>(UserService, { useValue: instance(mockedUserService) });
1820
wss = new MockWebSocketServer();
19-
createGateway(wss);
21+
gatewayController = createGateway(wss);
2022
});
2123

2224
function createWebSocket(): Promise<MockWebSocket> {
@@ -27,18 +29,18 @@ function createWebSocket(): Promise<MockWebSocket> {
2729
});
2830
}
2931

30-
let ws: MockWebSocket;
32+
let sockets: MockWebSocket[];
3133

3234
beforeEach(async () => {
3335
reset(spiedVerifyJWT);
3436
spiedVerifyJWT = spy(auth);
3537
reset(mockedUserService);
36-
ws = await createWebSocket();
38+
sockets = await Promise.all([0, 0].map(() => createWebSocket()));
3739
await wait();
3840
});
3941

4042
afterEach(async () => {
41-
await ws.close();
43+
await Promise.all(sockets.map(ws => ws.close()));
4244
await wait();
4345
});
4446

@@ -51,18 +53,62 @@ function wait() {
5153
describe('GatewayController', () => {
5254
describe('general', () => {
5355
test('Closes for unknown JSON', async () => {
54-
await ws.send(JSON.stringify({ z: 2 }));
55-
await expect(ws.nextMessage).rejects.toThrow();
56+
await sockets[0].send(JSON.stringify({ z: 2 }));
57+
await expect(sockets[0].nextMessage).rejects.toThrow();
5658
});
5759

5860
test('Closes for unknown packet', async () => {
59-
await ws.send(JSON.stringify({ type: -1 }));
60-
await expect(ws.nextMessage).rejects.toThrow();
61+
await sockets[0].send(JSON.stringify({ type: -1 }));
62+
await expect(sockets[0].nextMessage).rejects.toThrow();
6163
});
6264

6365
test('Closes for invalid packet', async () => {
64-
await ws.send('garbage');
65-
await expect(ws.nextMessage).rejects.toThrow();
66+
await sockets[0].send('garbage');
67+
await expect(sockets[0].nextMessage).rejects.toThrow();
68+
});
69+
70+
test('broadcast', async () => {
71+
// Authenticate the first socket
72+
gatewayController.authenticatedClients.set(sockets[0].mirror, '');
73+
74+
const payload = { time: Date.now() };
75+
const stringPayload = JSON.stringify(payload);
76+
77+
await gatewayController.broadcast(payload as any);
78+
await wait();
79+
await expect(sockets[0].nextMessage).resolves.toEqual(stringPayload);
80+
expect(sockets[1].allMessages.length).toEqual(0);
81+
82+
// Authenticate the second socket
83+
gatewayController.authenticatedClients.set(sockets[1].mirror, '');
84+
await gatewayController.broadcast(payload as any);
85+
await expect(sockets[0].nextMessage).resolves.toEqual(stringPayload);
86+
await expect(sockets[1].nextMessage).resolves.toEqual(stringPayload);
87+
});
88+
89+
test('sendTo', async () => {
90+
// Authenticate the first socket
91+
gatewayController.authenticatedClients.set(sockets[0].mirror, 'banana');
92+
gatewayController.authenticatedClients.set(sockets[1].mirror, 'apple');
93+
94+
const payload = { time: Date.now() };
95+
const stringPayload = JSON.stringify(payload);
96+
97+
await gatewayController.sendMessage(['banana'], payload as any);
98+
await expect(sockets[0].nextMessage).resolves.toEqual(stringPayload);
99+
expect(sockets[1].allMessages.length).toEqual(0);
100+
101+
await gatewayController.sendMessage(['apple'], payload as any);
102+
await expect(sockets[1].nextMessage).resolves.toEqual(stringPayload);
103+
expect(sockets[0].allMessages.length).toEqual(1);
104+
105+
await gatewayController.sendMessage(['apple', 'banana'], payload as any);
106+
await expect(sockets[0].nextMessage).resolves.toEqual(stringPayload);
107+
await expect(sockets[1].nextMessage).resolves.toEqual(stringPayload);
108+
109+
await gatewayController.sendMessage(['apple', 'banana', 'ghost'], payload as any);
110+
await expect(sockets[0].nextMessage).resolves.toEqual(stringPayload);
111+
await expect(sockets[1].nextMessage).resolves.toEqual(stringPayload);
66112
});
67113
});
68114

@@ -77,8 +123,8 @@ describe('GatewayController', () => {
77123

78124
when(spiedVerifyJWT.verifyJWT('123')).thenResolve({ id: '456' });
79125
when(mockedUserService.findOne(objectContaining({ id: '456' }))).thenResolve({} as any);
80-
await ws.send(JSON.stringify(payload));
81-
expect(JSON.parse(await ws.nextMessage)).toMatchObject({ type: GatewayPacketType.Hello });
126+
await sockets[0].send(JSON.stringify(payload));
127+
expect(JSON.parse(await sockets[0].nextMessage)).toMatchObject({ type: GatewayPacketType.Hello });
82128
verify(spiedVerifyJWT.verifyJWT('123')).once();
83129
verify(mockedUserService.findOne(objectContaining({ id: '456' }))).once();
84130
});
@@ -92,8 +138,8 @@ describe('GatewayController', () => {
92138
};
93139

94140
when(spiedVerifyJWT.verifyJWT('123')).thenReject(new Error('Test Error'));
95-
await ws.send(JSON.stringify(payload));
96-
await expect(ws.nextMessage).rejects.toThrow();
141+
await sockets[0].send(JSON.stringify(payload));
142+
await expect(sockets[0].nextMessage).rejects.toThrow();
97143
verify(spiedVerifyJWT.verifyJWT('123')).once();
98144
verify(mockedUserService.findOne(anything())).never();
99145
});
@@ -108,8 +154,8 @@ describe('GatewayController', () => {
108154

109155
when(spiedVerifyJWT.verifyJWT('123')).thenResolve({ id: '456' });
110156
when(mockedUserService.findOne(objectContaining({ id: '456' }))).thenReject(new Error('Test Error'));
111-
await ws.send(JSON.stringify(payload));
112-
await expect(ws.nextMessage).rejects.toThrow();
157+
await sockets[0].send(JSON.stringify(payload));
158+
await expect(sockets[0].nextMessage).rejects.toThrow();
113159
verify(spiedVerifyJWT.verifyJWT('123')).once();
114160
verify(mockedUserService.findOne(objectContaining({ id: '456' }))).once();
115161
});

0 commit comments

Comments
 (0)