Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 77 additions & 4 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,84 @@ exports.Pool = Pool;

exports.PoolCluster = PoolCluster;

exports.createServer = function (handler) {
const _serverHandlerKeys = ['query', 'ping', 'quit', 'init_db', 'auth'];

function _hasHandlerKeys(obj) {
return _serverHandlerKeys.some((k) => typeof obj[k] === 'function');
}

function _wrapAuth(authHandler) {
return function (params, cb) {
Promise.resolve()
.then(() => authHandler(params))
.then(() => cb(null))
.catch((err) =>
cb(null, { message: err.message, code: err.code || 1045 })
);
};
}

function _buildHandshakeArgs(handlers) {
const args = {
protocolVersion: 10,
serverVersion: handlers.serverVersion || 'mysql2-server',
connectionId: Math.floor(Math.random() * 1000000),
statusFlags: 2,
characterSet: 8,
capabilityFlags: 0xffffff,
};
if (handlers.auth) {
args.authCallback = _wrapAuth(handlers.auth);
}
return args;
}

exports.createServer = function (opts = {}) {
const Server = require('./lib/server.js');
const s = new Server();
if (handler) {
s.on('connection', handler);
const Commands = require('./lib/commands/index.js');
const { buildHandleCommand } = require('./lib/commands/server/index.js');

if (typeof opts === 'function') {
const fn = opts;
const s = new Server({ encoding: 'cesu8' });
s.on('connection', (conn) => {
conn.on('error', () => {});
const result = fn(conn);
if (!result || typeof result !== 'object' || !_hasHandlerKeys(result)) {
return;
}
const handlers = result;
const encoding = handlers.encoding || 'cesu8';
conn.serverConfig = { encoding };
conn.config.serverOptions = Object.assign({}, conn.config.serverOptions, {
handleCommand: buildHandleCommand(handlers),
encoding,
});
conn.addCommand(
new Commands.ServerHandshake(_buildHandshakeArgs(handlers))
);
});
return s;
}

if (_hasHandlerKeys(opts)) {
const handleCommand = buildHandleCommand(opts);
const encoding = opts.encoding || 'cesu8';
const s = new Server({ handleCommand, encoding });
s.on('connection', (conn) => {
conn.on('error', () => {});
conn.serverConfig = { encoding };
conn.addCommand(new Commands.ServerHandshake(_buildHandshakeArgs(opts)));
});
return s;
}

const s = new Server({
handleCommand: opts.handleCommand,
encoding: opts.encoding || 'cesu8',
});
if (opts.onConnection) {
s.on('connection', opts.onConnection);
}
return s;
};
Expand Down
42 changes: 28 additions & 14 deletions lib/base/connection.js
Original file line number Diff line number Diff line change
Expand Up @@ -515,14 +515,26 @@
);
}
}
if (
!this._command &&
this.config.isServer &&
this.config.serverOptions?.handleCommand
) {
const commandCode = packet.peekByte();
this._command = this.config.serverOptions.handleCommand(commandCode);
}
if (!this._command) {
const marker = packet.peekByte();
// If it's an Err Packet, we should use it.
if (marker === 0xff) {
const error = Packets.Error.fromPacket(packet);
this.protocolError(error.message, error.code);
} else if (this.config.isServer && !this.config.serverOptions?.handleCommand) {

Check failure on line 531 in lib/base/connection.js

View workflow job for this annotation

GitHub Actions / lint-js

Replace `this.config.isServer·&&·!this.config.serverOptions?.handleCommand` with `⏎········this.config.isServer·&&⏎········!this.config.serverOptions?.handleCommand⏎······`
this.protocolError(
'No handleCommand configured for server connection. ' +
'Provide a handleCommand option to createServer() to handle client commands.',
'PROTOCOL_UNEXPECTED_PACKET'
);
} else {
// Otherwise, it means it's some other unexpected packet.
this.protocolError(
'Unexpected packet while no commands in the queue',
'PROTOCOL_UNEXPECTED_PACKET'
Expand Down Expand Up @@ -1016,27 +1028,31 @@
// ===================================
// outgoing server connection methods
// ===================================

get _serverEncoding() {
return (
this.config.serverOptions?.encoding ||
(this.serverConfig && this.serverConfig.encoding) ||
'cesu8'
);
}

writeColumns(columns) {
this.writePacket(Packets.ResultSetHeader.toPacket(columns.length));
columns.forEach((column) => {
this.writePacket(
Packets.ColumnDefinition.toPacket(column, this.serverConfig.encoding)
Packets.ColumnDefinition.toPacket(column, this._serverEncoding)
);
});
this.writeEof();
}

// row is array of columns, not hash
writeTextRow(column) {
this.writePacket(
Packets.TextRow.toPacket(column, this.serverConfig.encoding)
);
this.writePacket(Packets.TextRow.toPacket(column, this._serverEncoding));
}

writeBinaryRow(column) {
this.writePacket(
Packets.BinaryRow.toPacket(column, this.serverConfig.encoding)
);
this.writePacket(Packets.BinaryRow.toPacket(column, this._serverEncoding));
}

writeTextResult(rows, columns, binary = false) {
Expand All @@ -1061,13 +1077,11 @@
if (!args) {
args = { affectedRows: 0 };
}
this.writePacket(Packets.OK.toPacket(args, this.serverConfig.encoding));
this.writePacket(Packets.OK.toPacket(args, this._serverEncoding));
}

writeError(args) {
// if we want to send error before initial hello was sent, use default encoding
const encoding = this.serverConfig ? this.serverConfig.encoding : 'cesu8';
this.writePacket(Packets.Error.toPacket(args, encoding));
this.writePacket(Packets.Error.toPacket(args, this._serverEncoding));
}

serverHandshake(args) {
Expand Down
6 changes: 4 additions & 2 deletions lib/commands/auth_switch.js
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,11 @@ function authSwitchRequest(packet, connection, command) {

const authPlugin = getAuthPlugin(pluginName, connection);
if (!authPlugin) {
throw new Error(
`Server requests authentication using unknown plugin ${pluginName}. See ${'TODO: add plugins doco here'} on how to configure or author authentication plugins.`
const err = new Error(
`Server requests authentication using unknown plugin ${pluginName}.`
);
connection.emit('error', err);
return;
}
connection._authPlugin = authPlugin({ connection, command });
Promise.resolve(connection._authPlugin(pluginData))
Expand Down
22 changes: 4 additions & 18 deletions lib/commands/client_handshake.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ const Command = require('./command.js');
const Packets = require('../packets/index.js');
const ClientConstants = require('../constants/client.js');
const CharsetToEncoding = require('../constants/charset_encodings.js');

// TODO: refactor to use plugins
// need to coordinate with ChangeUser command,
// currently it uses sync calculateNativePasswordAuthToken method from here
const auth41 = require('../auth_41.js');
const { getAuthPlugin } = require('./auth_switch.js');
const {
Expand Down Expand Up @@ -60,27 +64,17 @@ class ClientHandshake extends Command {
}
this.user = connection.config.user;
this.password = connection.config.password;
// "password1" is an alias to the original "password" value
// to make it easier to integrate multi-factor authentication
this.password1 = connection.config.password;
// "password2" and "password3" are the 2nd and 3rd factor authentication
// passwords, which can be undefined depending on the authentication
// plugin being used
this.password2 = connection.config.password2;
this.password3 = connection.config.password3;
this.passwordSha1 = connection.config.passwordSha1;
this.database = connection.config.database;
this.authPluginName = this.handshake.authPluginName;

// Optimization: Try to use the server's preferred authentication method
// to avoid an unnecessary auth switch roundtrip
const serverAuthMethod = this.handshake.authPluginName;
const isSecureConnection =
connection.config.ssl || connection.config.socketPath;

// Combine auth plugin data for easier handling
// Note: authPluginData2 can include a trailing NUL byte when PLUGIN_AUTH is set
// We must ensure exactly 20 bytes for the scramble
const authPluginData =
this.handshake.authPluginData1 && this.handshake.authPluginData2
? Buffer.concat([
Expand All @@ -89,8 +83,6 @@ class ClientHandshake extends Command {
]).slice(0, 20)
: Buffer.alloc(20);

// Check if user has custom auth plugin or legacy handler for the server-advertised method
// If so, we must not bypass the auth switch flow with our built-in implementation
const hasCustomAuthPlugin =
connection.config.authPlugins &&
Object.prototype.hasOwnProperty.call(
Expand All @@ -100,8 +92,6 @@ class ClientHandshake extends Command {
const hasLegacyAuthSwitchHandler =
typeof connection.config.authSwitchHandler === 'function';

// Determine which auth method to use
// Try to use server's preferred method if we can, otherwise fallback to native
const canUseDirectAuth =
!hasCustomAuthPlugin &&
!hasLegacyAuthSwitchHandler &&
Expand All @@ -113,7 +103,6 @@ class ClientHandshake extends Command {
? serverAuthMethod
: 'mysql_native_password';

// Calculate the auth token for the chosen method
const authToken = this.calculateAuthToken(
clientAuthMethod,
this.password,
Expand Down Expand Up @@ -144,9 +133,6 @@ class ClientHandshake extends Command {
});
connection.writePacket(handshakeResponse.toPacket());

// If we used a non-native auth method in the initial handshake response,
// we need to prepare for potential AuthMoreData packets by creating
// the appropriate auth plugin instance
if (clientAuthMethod !== 'mysql_native_password') {
this.initializeAuthPlugin(clientAuthMethod, authPluginData, connection);
}
Expand Down
3 changes: 3 additions & 0 deletions lib/commands/command.js
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ class Command extends EventEmitter {
if (!this.next) {
this.next = this.start;
connection._resetSequenceId();
if (connection.config.isServer && packet) {
connection._bumpSequenceId(1);
}
}
if (packet && packet.isError()) {
const err = packet.asError(connection.clientEncoding);
Expand Down
56 changes: 56 additions & 0 deletions lib/commands/server/index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
'use strict';

const CommandCode = require('../../constants/commands.js');
const ServerQuery = require('./query.js');
const ServerPing = require('./ping.js');
const ServerQuit = require('./quit.js');
const ServerInitDb = require('./init_db.js');
const { sendError } = require('./send_result.js');
const Command = require('../command.js');

function defaultPing() {}
function defaultQuit() {}
function defaultInitDb() {}

function buildHandleCommand(handlers) {
const queryHandler = handlers.query;
const pingHandler = handlers.ping || defaultPing;
const quitHandler = handlers.quit || defaultQuit;
const initDbHandler = handlers.init_db || defaultInitDb;
const fallback = handlers.handleCommand;

return function handleCommand(commandCode) {
switch (commandCode) {
case CommandCode.QUERY:
if (queryHandler) {
return new ServerQuery(queryHandler);
}
break;
case CommandCode.PING:
return new ServerPing(pingHandler);
case CommandCode.QUIT:
return new ServerQuit(quitHandler);
case CommandCode.INIT_DB:
return new ServerInitDb(initDbHandler);
}

if (fallback) {
return fallback(commandCode);
}

const cmd = new Command();
cmd.start = function (_packet, connection) {
sendError(connection, new Error('Command not supported'));
return null;
};
return cmd;
};
}

module.exports = {
ServerQuery,
ServerPing,
ServerQuit,
ServerInitDb,
buildHandleCommand,
};
50 changes: 50 additions & 0 deletions lib/commands/server/init_db.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
'use strict';

const Command = require('../command.js');
const { sendResult, sendError } = require('./send_result.js');

class ServerInitDb extends Command {
constructor(handler) {
super();
this._handler = handler;
}

start(packet, connection) {
packet.readInt8();
const encoding =
(connection.clientHelloReply && connection.clientHelloReply.encoding) ||
'utf8';
const schemaName = packet.readString(undefined, encoding);
let result;
try {
result = this._handler(schemaName);
} catch (err) {
sendError(connection, err);
return null;
}
if (result && typeof result.then === 'function') {
result
.then(() => sendResult(connection, undefined))
.catch((err) => sendError(connection, err))
.then(() => {
this.next = null;
this.emit('end');
connection._command = connection._commands.shift();
if (connection._command) {
connection.sequenceId = 0;
connection.compressedSequenceId = 0;
connection.handlePacket();
}
});
return ServerInitDb.prototype._awaitResult;
}
sendResult(connection, undefined);
return null;
}

_awaitResult() {
return ServerInitDb.prototype._awaitResult;
}
}

module.exports = ServerInitDb;
Loading
Loading