Skip to content

Commit 41c17d7

Browse files
committed
fix numeric and custom tokens
1 parent 991ba57 commit 41c17d7

File tree

5 files changed

+132
-77
lines changed

5 files changed

+132
-77
lines changed

src/defines.ts

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,8 @@ export interface ParamTypes {
8181
numbered?: Array<"?" | ":" | "$">,
8282
named?: Array<":" | "@" | "$">,
8383
quoted?: Array<":" | "@" | "$">,
84-
// regex is for identifying that it is a param, key is how the token is translated to an object value for the formatter,
85-
// may not be necessary here, we shal see
86-
custom?: Array<{regex: string, key?: (text: string) => string }>
84+
// regex for identifying that it is a param
85+
custom?: Array<string>
8786
}
8887

8988
export interface IdentifyOptions {

src/index.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify
2222
}
2323

2424
const result = parse(query, isStrict, dialect, options.identifyTables, options.paramTypes);
25+
const sort = dialect === 'psql' && !options.paramTypes;
2526

2627
return result.body.map((statement) => {
2728
const result: IdentifyResult = {
@@ -31,7 +32,7 @@ export function identify(query: string, options: IdentifyOptions = {}): Identify
3132
type: statement.type,
3233
executionType: statement.executionType,
3334
// we want to sort the postgres params: $1 $2 $3, regardless of the order they appear
34-
parameters: dialect === 'psql' ? statement.parameters.sort() : statement.parameters,
35+
parameters: sort ? statement.parameters.sort() : statement.parameters,
3536
tables: statement.tables || [],
3637
};
3738
return result;

src/tokenizer.ts

Lines changed: 72 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -254,85 +254,85 @@ function scanString(state: State, endToken: Char): Token {
254254
}
255255

256256
function getCustomParam(state: State, paramTypes: ParamTypes): string | null | undefined {
257-
const matches = paramTypes?.custom?.map(({ regex }) => {
258-
const reg = new RegExp(`(?:${regex})`, 'u');
259-
return reg.exec(state.input);
257+
const matches = paramTypes?.custom?.map((regex) => {
258+
const reg = new RegExp(`^(?:${regex})`, 'u');
259+
return reg.exec(state.input.slice(state.start));
260260
}).filter((value) => !!value)[0];
261261

262262
return matches ? matches[0] : null;
263263
}
264264

265-
function scanParameter(state: State, dialect: Dialect, paramTypes?: ParamTypes): Token {
266-
// user has defined wanted param types, so we only evaluate them
267-
if (paramTypes) {
268-
const curCh: any = state.input[0];
269-
let nextChar = peek(state);
270-
let matched = false
265+
function scanCustomParameter(state: State, dialect: Dialect, paramTypes: ParamTypes): Token {
271266

272-
// this could be a named parameter that just starts with a number (ugh)
273-
if (paramTypes.numbered && paramTypes.numbered.length && paramTypes.numbered.includes(curCh)) {
274-
const maybeNumbers = state.input.slice(1, state.input.length);
275-
if (nextChar !== null && !isNaN(Number(nextChar)) && /^\d+$/.test(maybeNumbers)) {
276-
do {
277-
nextChar = read(state);
278-
} while (nextChar !== null && !isNaN(Number(nextChar)) && !isWhitespace(nextChar));
279-
280-
if (nextChar !== null) unread(state);
281-
matched = true;
282-
}
283-
}
284-
285-
if (!matched && paramTypes.named && paramTypes.named.length && paramTypes.named.includes(curCh)) {
286-
if (!isQuotedIdentifier(nextChar, dialect)) {
287-
while (isAlphaNumeric(peek(state))) read(state);
288-
matched = true;
289-
}
290-
}
291-
292-
if (!matched && paramTypes.quoted && paramTypes.quoted.length && paramTypes.quoted.includes(curCh)) {
293-
if (isQuotedIdentifier(nextChar, dialect)) {
294-
const endChars = new Map<string, string>([
295-
['"', '"'],
296-
['[', ']'],
297-
['`', '`']
298-
]);
299-
const quoteChar = read(state) as string;
300-
const end = endChars.get(quoteChar);
301-
// end when we reach the end quote
302-
while ((isAlphaNumeric(peek(state)) || peek(state) === ' ') && peek(state) != end) read(state);
303-
304-
// read the end quote
305-
read(state);
306-
307-
matched = true;
308-
}
309-
}
310-
311-
if (!matched && paramTypes.custom && paramTypes.custom.length) {
312-
const custom = getCustomParam(state, paramTypes);
267+
const curCh: any = state.input[state.start];
268+
let nextChar = peek(state);
269+
let matched = false
313270

314-
if (custom) {
315-
read(state, custom.length);
316-
matched = true;
317-
}
318-
}
319-
320-
if (!matched && curCh !== '?' && nextChar !== null) { // not positional, panic
321-
return {
322-
type: 'parameter',
323-
value: 'unknown',
324-
start: state.start,
325-
end: state.end
326-
}
327-
}
271+
if (paramTypes.numbered && paramTypes.numbered.length && paramTypes.numbered.includes(curCh)) {
272+
const endIndex = state.input.slice(state.start).split('').findIndex((val) => isWhitespace(val));
273+
const maybeNumbers = state.input.slice(state.start + 1, endIndex > 0 ? state.start + endIndex : state.end + 1);
274+
if (nextChar !== null && !isNaN(Number(nextChar)) && /^\d+$/.test(maybeNumbers)) {
275+
let nextChar: Char = null;
276+
do {
277+
nextChar = read(state);
278+
} while (nextChar !== null && !isNaN(Number(nextChar)) && !isWhitespace(nextChar));
328279

329-
const value = state.input.slice(state.start, state.position + 1);
280+
if (nextChar !== null) unread(state);
281+
matched = true;
282+
}
283+
}
284+
285+
if (!matched && paramTypes.named && paramTypes.named.length && paramTypes.named.includes(curCh)) {
286+
if (!isQuotedIdentifier(nextChar, dialect)) {
287+
while (isAlphaNumeric(peek(state))) read(state);
288+
matched = true;
289+
}
290+
}
291+
292+
if (!matched && paramTypes.quoted && paramTypes.quoted.length && paramTypes.quoted.includes(curCh)) {
293+
if (isQuotedIdentifier(nextChar, dialect)) {
294+
const quoteChar = read(state) as string;
295+
// end when we reach the end quote
296+
while ((isAlphaNumeric(peek(state)) || peek(state) === ' ') && peek(state) != ENDTOKENS[quoteChar]) read(state);
297+
298+
// read the end quote
299+
read(state);
300+
301+
matched = true;
302+
}
303+
}
304+
305+
if (!matched && paramTypes.custom && paramTypes.custom.length) {
306+
const custom = getCustomParam(state, paramTypes);
307+
308+
if (custom) {
309+
read(state, custom.length);
310+
matched = true;
311+
}
312+
}
313+
314+
if (!matched && !paramTypes.positional) { // not positional, panic
330315
return {
331316
type: 'parameter',
332-
value,
317+
value: 'unknown',
333318
start: state.start,
334-
end: state.start + value.length - 1,
335-
};
319+
end: state.end
320+
}
321+
}
322+
323+
const value = state.input.slice(state.start, state.position + 1);
324+
return {
325+
type: 'parameter',
326+
value,
327+
start: state.start,
328+
end: state.start + value.length - 1,
329+
};
330+
}
331+
332+
function scanParameter(state: State, dialect: Dialect, paramTypes?: ParamTypes): Token {
333+
// user has defined wanted param types, so we only evaluate them
334+
if (paramTypes) {
335+
return scanCustomParameter(state, dialect, paramTypes);
336336
}
337337

338338
if (['mysql', 'generic', 'sqlite'].includes(dialect)) {
@@ -495,17 +495,17 @@ function isString(ch: Char, dialect: Dialect): boolean {
495495
}
496496

497497
function isCustomParam(state: State, paramTypes: ParamTypes): boolean | undefined {
498-
return paramTypes?.custom?.some(({ regex }) => {
499-
const reg = new RegExp(`(?:${regex})`, 'uy');
500-
return reg.test(state.input);
498+
return paramTypes?.custom?.some((regex) => {
499+
const reg = new RegExp(`^(?:${regex})`, 'uy');
500+
return reg.test(state.input.slice(state.start));
501501
})
502502
}
503503

504504
function isParameter(ch: Char, state: State, dialect: Dialect, paramTypes?: ParamTypes): boolean {
505505
if (paramTypes && ch !== null) {
506506
const curCh: any = ch;
507507
const nextChar = peek(state);
508-
if (paramTypes.positional && ch === '?' && nextChar === null) return true;
508+
if (paramTypes.positional && ch === '?' && (nextChar === null || isWhitespace(nextChar))) return true;
509509

510510
if (paramTypes.numbered && paramTypes.numbered.length && paramTypes.numbered.includes(curCh)) {
511511
if (nextChar !== null && !isNaN(Number(nextChar))) {

test/index.spec.ts

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { Dialect, getExecutionType, identify } from '../src/index';
22
import { expect } from 'chai';
3+
import { ParamTypes } from '../src/defines';
34

45
describe('identify', () => {
56
it('should throw error for invalid dialect', () => {
@@ -22,6 +23,29 @@ describe('identify', () => {
2223
]);
2324
});
2425

26+
it('should identify custom parameters', () => {
27+
const paramTypes: ParamTypes = {
28+
positional: true,
29+
numbered: ['$'],
30+
named: [':'],
31+
quoted: [':'],
32+
custom: [ '\\{[a-zA-Z0-9_]+\\}' ]
33+
};
34+
const query = `SELECT * FROM foo WHERE bar = ? AND baz = $1 AND fizz = :fizzz AND buzz = :"buzz buzz" AND foo2 = {fooo}`;
35+
36+
expect(identify(query, { dialect: 'psql', paramTypes })).to.eql([
37+
{
38+
start: 0,
39+
end: 104,
40+
text: query,
41+
type: 'SELECT',
42+
executionType: 'LISTING',
43+
parameters: ['?', '$1', ':fizzz', ':"buzz buzz"', '{fooo}'],
44+
tables: []
45+
}
46+
])
47+
})
48+
2549
it('should identify tables in simple for basic cases', () => {
2650
expect(
2751
identify('SELECT * FROM foo JOIN bar ON foo.id = bar.id', { identifyTables: true }),

test/tokenizer/index.spec.ts

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ describe('scan', () => {
562562

563563
it('should allow custom parameters for all dialects', () => {
564564
const paramTypes: ParamTypes = {
565-
custom: [{ regex: '\\{[a-zA-Z0-9_]+\\}' }]
565+
custom: [ '\\{[a-zA-Z0-9_]+\\}' ]
566566
};
567567

568568
const expected = {
@@ -585,6 +585,37 @@ describe('scan', () => {
585585
quoted: [':'],
586586
custom: []
587587
};
588+
589+
const expected = [
590+
{
591+
type: 'parameter',
592+
value: '?',
593+
start: 0,
594+
end: 0
595+
},
596+
{
597+
type: 'parameter',
598+
value: ':123',
599+
start: 0,
600+
end: 3
601+
},
602+
{
603+
type: 'parameter',
604+
value: ':123hello',
605+
start: 0,
606+
end: 8
607+
},
608+
{
609+
type: 'parameter',
610+
value: ':"named param"',
611+
start: 0,
612+
end: 13
613+
}
614+
];
615+
616+
expected.forEach((expected) => {
617+
expect(scanToken(initState(expected.value), 'mssql', paramTypes)).to.eql(expected);
618+
})
588619
})
589620
});
590621
});

0 commit comments

Comments
 (0)