Skip to content

Commit 9016cb3

Browse files
committed
Add createAggregateFunction method
1 parent 4324480 commit 9016cb3

8 files changed

+223
-32
lines changed

docs/.vitepress/config.ts

+4
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ export default defineConfig({
6969
text: 'createScalarFunction',
7070
link: '/api/createscalarfunction',
7171
},
72+
{
73+
text: 'createAggregateFunction',
74+
link: '/api/createaggregatefunction',
75+
},
7276
{
7377
text: 'destroy',
7478
link: '/api/destroy',

docs/api/createaggregatefunction.md

+52
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# createAggregateFunction
2+
3+
Create a SQL function that can be called from queries to combine multiple rows into a single result row.
4+
5+
## Usage
6+
7+
Access or destructure `createAggregateFunction` from the `SQLocal` client.
8+
9+
```javascript
10+
import { SQLocal } from 'sqlocal';
11+
12+
const { createAggregateFunction } = new SQLocal('database.sqlite3');
13+
```
14+
15+
<!-- @include: ../.partials/initialization-note.md -->
16+
17+
This method takes a string to name a custom SQL function as its first argument and an object containing two functions (`step` and `final`) as its second argument. After running `createAggregateFunction`, the aggregate function that you defined can be called from subsequent SQL queries. Arguments passed to the function in the SQL query will be passed to the JavaScript `step` function. The `step` function will run for every row in the SQL query. After each row is processed, the `final` function will run, and its return value will be passed back to SQLite to use to complete the query.
18+
19+
This can be used to combine rows together in a query based on some custom logic. For example, the below aggregate function can be used to find the most common value for a column, such as the most common category used in a table of tasks.
20+
21+
```javascript
22+
const values = new Map();
23+
24+
await createAggregateFunction('mostCommon', {
25+
step: (value) => {
26+
const valueCount = values.get(value) ?? 0;
27+
values.set(value, valueCount + 1);
28+
},
29+
final: () => {
30+
const valueEntries = Array.from(values.entries());
31+
const sortedEntries = valueEntries.sort((a, b) => b[1] - a[1]);
32+
const mostCommonValue = sortedEntries[0][0];
33+
values.clear();
34+
return mostCommonValue;
35+
},
36+
});
37+
38+
await sql`SELECT mostCommon(category) AS mostCommonCategory FROM tasks`;
39+
```
40+
41+
Aggregate functions can also be used in a query's HAVING clause to filter groups of rows. Here, we use the `mostCommon` function that we created in the previous example to find which days of the week have "Cleaning" as the most common category of task.
42+
43+
```javascript
44+
await sql`
45+
SELECT dayOfWeek
46+
FROM tasks
47+
GROUP BY dayOfWeek
48+
HAVING mostCommon(category) = 'Cleaning'
49+
`;
50+
```
51+
52+
<!-- @include: ../.partials/functions-note.md -->

src/client.ts

+31-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import type {
1111
StatementInput,
1212
Transaction,
1313
DatabasePath,
14+
AggregateUserFunction,
1415
} from './types.js';
1516
import type {
1617
BatchMessage,
@@ -391,9 +392,12 @@ export class SQLocal {
391392
func: ScalarUserFunction['func']
392393
): Promise<void> => {
393394
const key = `_sqlocal_func_${funcName}`;
395+
const attachFunction = () => {
396+
this.proxy[key] = func;
397+
};
394398

395399
if (this.proxy === globalThis) {
396-
this.proxy[key] = func;
400+
attachFunction();
397401
}
398402

399403
await this.createQuery({
@@ -403,7 +407,32 @@ export class SQLocal {
403407
});
404408

405409
if (this.proxy !== globalThis) {
406-
this.proxy[key] = func;
410+
attachFunction();
411+
}
412+
};
413+
414+
createAggregateFunction = async (
415+
funcName: string,
416+
func: AggregateUserFunction['func']
417+
): Promise<void> => {
418+
const key = `_sqlocal_func_${funcName}`;
419+
const attachFunction = () => {
420+
this.proxy[`${key}_step`] = func.step;
421+
this.proxy[`${key}_final`] = func.final;
422+
};
423+
424+
if (this.proxy === globalThis) {
425+
attachFunction();
426+
}
427+
428+
await this.createQuery({
429+
type: 'function',
430+
functionName: funcName,
431+
functionType: 'aggregate',
432+
});
433+
434+
if (this.proxy !== globalThis) {
435+
attachFunction();
407436
}
408437
};
409438

src/drivers/sqlite-memory-driver.ts

+18-5
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,24 @@ export class SQLiteMemoryDriver implements SQLocalDriver {
8787
async createFunction(fn: UserFunction): Promise<void> {
8888
if (!this.db) throw new Error('Driver not initialized');
8989

90-
this.db.createFunction({
91-
name: fn.name,
92-
xFunc: (_: number, ...args: any[]) => fn.func(...args),
93-
arity: -1,
94-
});
90+
switch (fn.type) {
91+
case 'callback':
92+
case 'scalar':
93+
this.db.createFunction({
94+
name: fn.name,
95+
xFunc: (_: number, ...args: any[]) => fn.func(...args),
96+
arity: -1,
97+
});
98+
break;
99+
case 'aggregate':
100+
this.db.createFunction({
101+
name: fn.name,
102+
xStep: (_: number, ...args: any[]) => fn.func.step(...args),
103+
xFinal: (_: number, ...args: any[]) => fn.func.final(...args),
104+
arity: -1,
105+
});
106+
break;
107+
}
95108
}
96109

97110
async import(

src/processor.ts

+32-19
Original file line numberDiff line numberDiff line change
@@ -266,38 +266,51 @@ export class SQLocalProcessor {
266266
protected createUserFunction = async (
267267
message: FunctionMessage
268268
): Promise<void> => {
269-
const { functionName, functionType, queryKey } = message;
270-
let func;
269+
const { functionName: name, functionType: type, queryKey } = message;
270+
let fn: UserFunction;
271271

272-
if (this.userFunctions.has(functionName)) {
272+
if (this.userFunctions.has(name)) {
273273
this.emitMessage({
274274
type: 'error',
275275
error: new Error(
276-
`A user-defined function with the name "${functionName}" has already been created for this SQLocal instance.`
276+
`A user-defined function with the name "${name}" has already been created for this SQLocal instance.`
277277
),
278278
queryKey,
279279
});
280280
return;
281281
}
282282

283-
if (functionType === 'callback') {
284-
func = (...args: any[]) => {
285-
this.emitMessage({
286-
type: 'callback',
287-
name: functionName,
288-
args: args,
289-
});
290-
};
291-
} else {
292-
func = this.proxy[`_sqlocal_func_${functionName}`];
283+
switch (type) {
284+
case 'callback':
285+
fn = {
286+
type,
287+
name,
288+
func: (...args: any[]) => {
289+
this.emitMessage({ type: 'callback', name, args });
290+
},
291+
};
292+
break;
293+
case 'scalar':
294+
fn = {
295+
type,
296+
name,
297+
func: this.proxy[`_sqlocal_func_${name}`],
298+
};
299+
break;
300+
case 'aggregate':
301+
fn = {
302+
type,
303+
name,
304+
func: {
305+
step: this.proxy[`_sqlocal_func_${name}_step`],
306+
final: this.proxy[`_sqlocal_func_${name}_final`],
307+
},
308+
};
309+
break;
293310
}
294311

295312
try {
296-
await this.initUserFunction({
297-
type: functionType,
298-
name: functionName,
299-
func,
300-
});
313+
await this.initUserFunction(fn);
301314
this.emitMessage({
302315
type: 'success',
303316
queryKey,

src/types.ts

+12-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,10 @@ export type DatabaseInfo = {
123123

124124
// User functions
125125

126-
export type UserFunction = CallbackUserFunction | ScalarUserFunction;
126+
export type UserFunction =
127+
| CallbackUserFunction
128+
| ScalarUserFunction
129+
| AggregateUserFunction;
127130
export type CallbackUserFunction = {
128131
type: 'callback';
129132
name: string;
@@ -134,3 +137,11 @@ export type ScalarUserFunction = {
134137
name: string;
135138
func: (...args: any[]) => any;
136139
};
140+
export type AggregateUserFunction = {
141+
type: 'aggregate';
142+
name: string;
143+
func: {
144+
step: (...args: any[]) => void;
145+
final: (...args: any[]) => any;
146+
};
147+
};
+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import { afterEach, beforeAll, beforeEach, describe, expect, it } from 'vitest';
2+
import { SQLocal } from '../src/index.js';
3+
4+
describe.each([
5+
{ type: 'opfs', path: 'create-aggregate-function-test.sqlite3' },
6+
{ type: 'memory', path: ':memory:' },
7+
{ type: 'local', path: ':localStorage:' },
8+
{ type: 'session', path: ':sessionStorage:' },
9+
])('createAggregateFunction ($type)', ({ path }) => {
10+
const { sql, createAggregateFunction } = new SQLocal(path);
11+
12+
beforeAll(async () => {
13+
const values = new Map<unknown, number>();
14+
15+
await createAggregateFunction('mostCommon', {
16+
step: (value: unknown) => {
17+
const valueCount = values.get(value) ?? 0;
18+
values.set(value, valueCount + 1);
19+
},
20+
final: () => {
21+
const valueEntries = Array.from(values.entries());
22+
const sortedEntries = valueEntries.sort((a, b) => b[1] - a[1]);
23+
const mostCommonValue = sortedEntries[0][0];
24+
values.clear();
25+
return mostCommonValue;
26+
},
27+
});
28+
});
29+
30+
beforeEach(async () => {
31+
await sql`CREATE TABLE nums (num REAL NOT NULL)`;
32+
});
33+
34+
afterEach(async () => {
35+
await sql`DROP TABLE nums`;
36+
});
37+
38+
it('should create and use aggregate function in SELECT clause', async () => {
39+
await sql`INSERT INTO nums (num) VALUES (0), (3), (2), (7), (3), (1), (5), (3), (3), (2)`;
40+
41+
const results = await sql`SELECT mostCommon(num) AS mostCommon FROM nums`;
42+
43+
expect(results).toEqual([{ mostCommon: 3 }]);
44+
});
45+
46+
it('should create and use aggregate function in HAVING clause', async () => {
47+
await sql`INSERT INTO nums (num) VALUES (1), (2), (2), (2), (4), (5), (5), (6)`;
48+
49+
const results = await sql`
50+
SELECT mod(num, 2) AS isOdd
51+
FROM nums
52+
GROUP BY isOdd
53+
HAVING mostCommon(num) = 5
54+
`;
55+
56+
expect(results).toEqual([{ isOdd: 1 }]);
57+
});
58+
59+
it('should not replace an existing implementation', async () => {
60+
const createBadFn = async () => {
61+
await createAggregateFunction('mostCommon', {
62+
step: () => {},
63+
final: () => 0,
64+
});
65+
};
66+
67+
await expect(createBadFn).rejects.toThrowError();
68+
});
69+
});

test/create-scalar-function.test.ts

+5-5
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ describe.each([
1717
await sql`DROP TABLE nums`;
1818
});
1919

20-
it('should create and use scalar function in columns clause', async () => {
20+
it('should create and use scalar function in SELECT clause', async () => {
2121
await createScalarFunction('double', (num: number) => num * 2);
2222

2323
const createBadFn = async () =>
@@ -26,7 +26,7 @@ describe.each([
2626

2727
await sql`INSERT INTO nums (num) VALUES (0), (2), (3.5), (-11.11)`;
2828

29-
const results = await sql`SELECT num, double(num) as doubled FROM nums`;
29+
const results = await sql`SELECT num, double(num) AS doubled FROM nums`;
3030

3131
expect(results).toEqual([
3232
{ num: 0, doubled: 0 },
@@ -36,7 +36,7 @@ describe.each([
3636
]);
3737
});
3838

39-
it('should create and use scalar function in where clause', async () => {
39+
it('should create and use scalar function in WHERE clause', async () => {
4040
await createScalarFunction('isEven', (num: number) => num % 2 === 0);
4141

4242
await sql`INSERT INTO nums (num) VALUES (2), (3), (4), (5), (6)`;
@@ -74,8 +74,8 @@ describe.each([
7474
await db1.createScalarFunction('addTax', (num: number) => num * 1.06);
7575
await db2.createScalarFunction('addTax', (num: number) => num * 1.07);
7676

77-
const [result1] = await db1.sql`SELECT addTax(2) as withTax`;
78-
const [result2] = await db2.sql`SELECT addTax(2) as withTax`;
77+
const [result1] = await db1.sql`SELECT addTax(2) AS withTax`;
78+
const [result2] = await db2.sql`SELECT addTax(2) AS withTax`;
7979

8080
expect(result1.withTax).toBe(2.12);
8181
expect(result2.withTax).toBe(2.14);

0 commit comments

Comments
 (0)