Skip to content

Commit 967c31e

Browse files
committed
chore(embedded/sql): Implement CASE statement
Signed-off-by: Stefano Scafiti <[email protected]>
1 parent e77545f commit 967c31e

File tree

8 files changed

+945
-298
lines changed

8 files changed

+945
-298
lines changed

embedded/sql/engine_test.go

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2885,6 +2885,164 @@ func TestQuery(t *testing.T) {
28852885
}
28862886
})
28872887

2888+
t.Run("query with case when then", func(t *testing.T) {
2889+
_, _, err := engine.Exec(
2890+
context.Background(),
2891+
nil,
2892+
`CREATE TABLE employees (
2893+
employee_id INTEGER AUTO_INCREMENT,
2894+
first_name VARCHAR[50],
2895+
last_name VARCHAR[50],
2896+
department VARCHAR[50],
2897+
salary INTEGER,
2898+
hire_date TIMESTAMP,
2899+
job_title VARCHAR[50],
2900+
2901+
PRIMARY KEY employee_id
2902+
);`,
2903+
nil,
2904+
)
2905+
require.NoError(t, err)
2906+
2907+
n := 100
2908+
for i := 0; i < n; i++ {
2909+
_, _, err := engine.Exec(
2910+
context.Background(),
2911+
nil,
2912+
`INSERT INTO employees(first_name, last_name, department, salary, job_title)
2913+
VALUES (@first_name, @last_name, @department, @salary, @job_title)
2914+
`,
2915+
map[string]interface{}{
2916+
"first_name": fmt.Sprintf("name%d", i),
2917+
"last_name": fmt.Sprintf("surname%d", i),
2918+
"department": []string{"sales", "manager", "engineering"}[rand.Intn(3)],
2919+
"salary": []int64{20, 40, 50, 80, 100}[rand.Intn(5)] * 1000,
2920+
"job_title": []string{"manager", "senior engineer", "executive"}[rand.Intn(3)],
2921+
},
2922+
)
2923+
require.NoError(t, err)
2924+
}
2925+
2926+
_, err = engine.queryAll(
2927+
context.Background(),
2928+
nil,
2929+
"SELECT CASE WHEN salary THEN 0 END FROM employees",
2930+
nil,
2931+
)
2932+
require.ErrorIs(t, err, ErrInvalidTypes)
2933+
2934+
rows, err := engine.queryAll(
2935+
context.Background(),
2936+
nil,
2937+
`SELECT
2938+
employee_id,
2939+
first_name,
2940+
last_name,
2941+
salary,
2942+
CASE
2943+
WHEN salary < 50000 THEN @low
2944+
WHEN salary >= 50000 AND salary <= 100000 THEN @medium
2945+
ELSE @high
2946+
END AS salary_category
2947+
FROM employees;`,
2948+
map[string]interface{}{
2949+
"low": "Low",
2950+
"medium": "Medium",
2951+
"high": "High",
2952+
},
2953+
)
2954+
require.NoError(t, err)
2955+
require.Len(t, rows, n)
2956+
2957+
for _, row := range rows {
2958+
salary := row.ValuesByPosition[3].RawValue().(int64)
2959+
category, _ := row.ValuesByPosition[4].RawValue().(string)
2960+
2961+
expectedCategory := "High"
2962+
if salary < 50000 {
2963+
expectedCategory = "Low"
2964+
} else if salary >= 50000 && salary <= 100000 {
2965+
expectedCategory = "Medium"
2966+
}
2967+
require.Equal(t, expectedCategory, category)
2968+
}
2969+
2970+
rows, err = engine.queryAll(
2971+
context.Background(),
2972+
nil,
2973+
`SELECT
2974+
department,
2975+
job_title,
2976+
CASE
2977+
WHEN department = 'sales' THEN
2978+
CASE
2979+
WHEN job_title = 'manager' THEN '20% Bonus'
2980+
ELSE '10% Bonus'
2981+
END
2982+
WHEN department = 'engineering' THEN
2983+
CASE
2984+
WHEN job_title = 'senior engineer' THEN '15% Bonus'
2985+
ELSE '5% Bonus'
2986+
END
2987+
ELSE
2988+
CASE
2989+
WHEN job_title = 'executive' THEN '12% Bonus'
2990+
ELSE 'No Bonus'
2991+
END
2992+
END AS bonus
2993+
FROM employees;`,
2994+
nil,
2995+
)
2996+
require.NoError(t, err)
2997+
require.Len(t, rows, n)
2998+
2999+
for _, row := range rows {
3000+
department := row.ValuesByPosition[0].RawValue().(string)
3001+
job, _ := row.ValuesByPosition[1].RawValue().(string)
3002+
bonus, _ := row.ValuesByPosition[2].RawValue().(string)
3003+
3004+
var expectedBonus string
3005+
switch department {
3006+
case "sales":
3007+
if job == "manager" {
3008+
expectedBonus = "20% Bonus"
3009+
} else {
3010+
expectedBonus = "10% Bonus"
3011+
}
3012+
case "engineering":
3013+
if job == "senior engineer" {
3014+
expectedBonus = "15% Bonus"
3015+
} else {
3016+
expectedBonus = "5% Bonus"
3017+
}
3018+
default:
3019+
if job == "executive" {
3020+
expectedBonus = "12% Bonus"
3021+
} else {
3022+
expectedBonus = "No Bonus"
3023+
}
3024+
}
3025+
require.Equal(t, expectedBonus, bonus)
3026+
}
3027+
3028+
rows, err = engine.queryAll(
3029+
context.Background(),
3030+
nil,
3031+
`SELECT
3032+
CASE
3033+
WHEN department = 'sales' THEN 'Sales Team'
3034+
END AS department
3035+
FROM employees
3036+
WHERE department != 'sales'
3037+
LIMIT 1
3038+
;`,
3039+
nil,
3040+
)
3041+
require.NoError(t, err)
3042+
require.Len(t, rows, 1)
3043+
require.Nil(t, rows[0].ValuesByPosition[0].RawValue())
3044+
})
3045+
28883046
t.Run("invalid queries", func(t *testing.T) {
28893047
r, err = engine.Query(context.Background(), nil, "INVALID QUERY", nil)
28903048
require.ErrorIs(t, err, ErrParsingError)

embedded/sql/parser.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,11 @@ var reservedWords = map[string]int{
110110
"PRIVILEGES": PRIVILEGES,
111111
"CHECK": CHECK,
112112
"CONSTRAINT": CONSTRAINT,
113+
"CASE": CASE,
114+
"WHEN": WHEN,
115+
"THEN": THEN,
116+
"ELSE": ELSE,
117+
"END": END,
113118
}
114119

115120
var joinTypes = map[string]JoinType{

embedded/sql/parser_test.go

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1386,7 +1386,7 @@ func TestAggFnStmt(t *testing.T) {
13861386
}
13871387
}
13881388

1389-
func TestExpressions(t *testing.T) {
1389+
func TestParseExp(t *testing.T) {
13901390
testCases := []struct {
13911391
input string
13921392
expectedOutput []SQLStmt
@@ -1674,6 +1674,98 @@ func TestExpressions(t *testing.T) {
16741674
}},
16751675
expectedError: nil,
16761676
},
1677+
{
1678+
input: "SELECT CASE WHEN is_deleted OR is_expired THEN 1 END AS is_deleted_or_expired FROM my_table",
1679+
expectedOutput: []SQLStmt{
1680+
&SelectStmt{
1681+
ds: &tableRef{table: "my_table"},
1682+
targets: []TargetEntry{
1683+
{
1684+
Exp: &CaseWhenExp{
1685+
whenThen: []whenThenClause{
1686+
{
1687+
when: &BinBoolExp{
1688+
op: OR,
1689+
left: &ColSelector{col: "is_deleted"},
1690+
right: &ColSelector{col: "is_expired"},
1691+
},
1692+
then: &Integer{1},
1693+
},
1694+
},
1695+
},
1696+
As: "is_deleted_or_expired",
1697+
},
1698+
},
1699+
},
1700+
},
1701+
},
1702+
{
1703+
input: "SELECT CASE WHEN is_active THEN 1 ELSE 2 END FROM my_table",
1704+
expectedOutput: []SQLStmt{
1705+
&SelectStmt{
1706+
ds: &tableRef{table: "my_table"},
1707+
targets: []TargetEntry{
1708+
{
1709+
Exp: &CaseWhenExp{
1710+
whenThen: []whenThenClause{
1711+
{
1712+
when: &ColSelector{col: "is_active"},
1713+
then: &Integer{1},
1714+
},
1715+
},
1716+
elseExp: &Integer{2},
1717+
},
1718+
},
1719+
},
1720+
},
1721+
},
1722+
},
1723+
{
1724+
input: `
1725+
SELECT product_name,
1726+
CASE
1727+
WHEN stock < 10 THEN 'Low stock'
1728+
WHEN stock >= 10 AND stock <= 50 THEN 'Medium stock'
1729+
WHEN stock > 50 THEN 'High stock'
1730+
ELSE 'Out of stock'
1731+
END AS stock_status
1732+
FROM products
1733+
`,
1734+
expectedOutput: []SQLStmt{
1735+
&SelectStmt{
1736+
ds: &tableRef{table: "products"},
1737+
targets: []TargetEntry{
1738+
{
1739+
Exp: &ColSelector{col: "product_name"},
1740+
},
1741+
{
1742+
Exp: &CaseWhenExp{
1743+
whenThen: []whenThenClause{
1744+
{
1745+
when: &CmpBoolExp{op: LT, left: &ColSelector{col: "stock"}, right: &Integer{10}},
1746+
then: &Varchar{"Low stock"},
1747+
},
1748+
{
1749+
when: &BinBoolExp{
1750+
op: AND,
1751+
left: &CmpBoolExp{op: GE, left: &ColSelector{col: "stock"}, right: &Integer{10}},
1752+
right: &CmpBoolExp{op: LE, left: &ColSelector{col: "stock"}, right: &Integer{50}},
1753+
},
1754+
then: &Varchar{"Medium stock"},
1755+
},
1756+
{
1757+
when: &CmpBoolExp{op: GT, left: &ColSelector{col: "stock"}, right: &Integer{50}},
1758+
then: &Varchar{"High stock"},
1759+
},
1760+
},
1761+
elseExp: &Varchar{"Out of stock"},
1762+
},
1763+
As: "stock_status",
1764+
},
1765+
},
1766+
},
1767+
},
1768+
},
16771769
}
16781770

16791771
for i, tc := range testCases {
@@ -1897,6 +1989,9 @@ func TestExprString(t *testing.T) {
18971989
"((col1 AND (col2 < 10)) OR (@param = 3 AND (col4 = TRUE))) AND NOT (col5 = 'value' OR (2 + 2 != 4))",
18981990
"CAST (func_call(1, 'two', 2.5) AS TIMESTAMP)",
18991991
"col IN (TRUE, 1, 'test', 1.5)",
1992+
"CASE WHEN in_stock THEN 'In Stock' END",
1993+
"CASE WHEN 1 > 0 THEN 1 ELSE 0 END",
1994+
"CASE WHEN is_active THEN 'active' WHEN is_expired THEN 'expired' ELSE 'active' END",
19001995
}
19011996

19021997
for i, e := range exps {

embedded/sql/proj_row_reader.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,12 @@ func (pr *projectedRowReader) Read(ctx context.Context) (*Row, error) {
190190
}
191191

192192
for i, t := range pr.targets {
193-
v, err := t.Exp.reduce(pr.Tx(), row, pr.rowReader.TableAlias())
193+
e, err := t.Exp.substitute(pr.Parameters())
194+
if err != nil {
195+
return nil, fmt.Errorf("%w: when evaluating WHERE clause", err)
196+
}
197+
198+
v, err := e.reduce(pr.Tx(), row, pr.rowReader.TableAlias())
194199
if err != nil {
195200
return nil, err
196201
}

embedded/sql/sql_grammar.y

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,14 @@ func setResult(l yyLexer, stmts []SQLStmt) {
7272
permission Permission
7373
sqlPrivilege SQLPrivilege
7474
sqlPrivileges []SQLPrivilege
75+
whenThenClauses []whenThenClause
7576
}
7677

7778
%token CREATE DROP USE DATABASE USER WITH PASSWORD READ READWRITE ADMIN SNAPSHOT HISTORY SINCE AFTER BEFORE UNTIL TX OF TIMESTAMP
7879
%token TABLE UNIQUE INDEX ON ALTER ADD RENAME TO COLUMN CONSTRAINT PRIMARY KEY CHECK GRANT REVOKE GRANTS FOR PRIVILEGES
7980
%token BEGIN TRANSACTION COMMIT ROLLBACK
8081
%token INSERT UPSERT INTO VALUES DELETE UPDATE SET CONFLICT DO NOTHING RETURNING
81-
%token SELECT DISTINCT FROM JOIN HAVING WHERE GROUP BY LIMIT OFFSET ORDER ASC DESC AS UNION ALL
82+
%token SELECT DISTINCT FROM JOIN HAVING WHERE GROUP BY LIMIT OFFSET ORDER ASC DESC AS UNION ALL CASE WHEN THEN ELSE END
8283
%token NOT LIKE IF EXISTS IN IS
8384
%token AUTO_INCREMENT NULL CAST SCAST
8485
%token SHOW DATABASES TABLES USERS
@@ -135,10 +136,10 @@ func setResult(l yyLexer, stmts []SQLStmt) {
135136
%type <join> join
136137
%type <joinType> opt_join_type
137138
%type <checks> opt_checks
138-
%type <exp> exp opt_where opt_having boundexp
139+
%type <exp> exp opt_where opt_having boundexp opt_else when_then_else
139140
%type <binExp> binExp
140141
%type <cols> opt_groupby
141-
%type <exp> opt_limit opt_offset
142+
%type <exp> opt_limit opt_offset case_when_exp
142143
%type <targets> opt_targets targets
143144
%type <integer> opt_max_len
144145
%type <id> opt_as
@@ -152,6 +153,7 @@ func setResult(l yyLexer, stmts []SQLStmt) {
152153
%type <permission> permission
153154
%type <sqlPrivilege> sqlPrivilege
154155
%type <sqlPrivileges> sqlPrivileges
156+
%type <whenThenClauses> when_then_clauses
155157

156158
%start sql
157159

@@ -1095,6 +1097,51 @@ exp:
10951097
{
10961098
$$ = &InListExp{val: $1, notIn: $2, values: $5}
10971099
}
1100+
|
1101+
case_when_exp
1102+
{
1103+
$$ = $1
1104+
}
1105+
1106+
case_when_exp:
1107+
CASE when_then_else END
1108+
{
1109+
$$ = $2
1110+
}
1111+
;
1112+
1113+
when_then_else:
1114+
when_then_clauses opt_else
1115+
{
1116+
$$ = &CaseWhenExp{
1117+
whenThen: $1,
1118+
elseExp: $2,
1119+
}
1120+
}
1121+
;
1122+
1123+
when_then_clauses:
1124+
WHEN exp THEN exp
1125+
{
1126+
$$ = []whenThenClause{{when: $2, then: $4}}
1127+
}
1128+
|
1129+
when_then_clauses WHEN exp THEN exp
1130+
{
1131+
$$ = append($1, whenThenClause{when: $3, then: $5})
1132+
}
1133+
;
1134+
1135+
opt_else:
1136+
{
1137+
$$ = nil
1138+
}
1139+
|
1140+
ELSE exp
1141+
{
1142+
$$ = $2
1143+
}
1144+
;
10981145

10991146
boundexp:
11001147
selector

0 commit comments

Comments
 (0)