Skip to content

Commit a2057db

Browse files
committed
chore(embedded/sql): add support for LEFT JOIN
Signed-off-by: Stefano Scafiti <[email protected]>
1 parent ae9af09 commit a2057db

File tree

3 files changed

+53
-45
lines changed

3 files changed

+53
-45
lines changed

embedded/sql/engine_test.go

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5986,7 +5986,7 @@ func TestNestedJoins(t *testing.T) {
59865986
require.NoError(t, err)
59875987
}
59885988

5989-
func TestLeftRightJoins(t *testing.T) {
5989+
func TestLeftJoins(t *testing.T) {
59905990
e := setupCommonTest(t)
59915991

59925992
_, _, err := e.Exec(
@@ -6055,10 +6055,28 @@ func TestLeftRightJoins(t *testing.T) {
60556055
)
60566056
require.NoError(t, err)
60576057

6058-
rows, err := e.queryAll(
6059-
context.Background(),
6060-
nil,
6061-
`SELECT
6058+
assertQueryShouldProduceResults(
6059+
t,
6060+
e,
6061+
`SELECT c.customer_id, c.customer_name, c.email, o.order_id, o.order_date
6062+
FROM customers c LEFT JOIN orders o ON c.customer_id = o.customer_id
6063+
ORDER BY c.customer_id, o.order_date;`,
6064+
`
6065+
SELECT *
6066+
FROM (
6067+
VALUES
6068+
(1, 'Alice Johnson', '[email protected]', 101, '2024-11-01'::TIMESTAMP),
6069+
(1, 'Alice Johnson', '[email protected]', 103, '2024-11-03'::TIMESTAMP),
6070+
(2, 'Bob Smith', '[email protected]', 102, '2024-11-02'::TIMESTAMP),
6071+
(3, 'Charlie Brown', '[email protected]', NULL, NULL)
6072+
)`,
6073+
)
6074+
6075+
assertQueryShouldProduceResults(
6076+
t,
6077+
e,
6078+
`
6079+
SELECT
60626080
c.customer_name,
60636081
c.email,
60646082
o.order_id,
@@ -6073,10 +6091,16 @@ func TestLeftRightJoins(t *testing.T) {
60736091
LEFT JOIN orders o ON oi.order_id = o.order_id
60746092
LEFT JOIN customers c ON o.customer_id = c.customer_id
60756093
ORDER BY o.order_date, c.customer_name;`,
6076-
nil,
6094+
`
6095+
SELECT *
6096+
FROM (
6097+
VALUES
6098+
('Alice Johnson', '[email protected]', 101, '2024-11-01'::TIMESTAMP, 'Laptop', 2, 1200.00, 2400.00),
6099+
('Alice Johnson', '[email protected]', 101, '2024-11-01'::TIMESTAMP, 'Smartphone', 1, 800.00, 800.00),
6100+
('Bob Smith', '[email protected]', 102, '2024-11-02'::TIMESTAMP, 'Tablet', 3, 400.00, 1200.00),
6101+
('Alice Johnson', '[email protected]', 103, '2024-11-03'::TIMESTAMP, 'Smartphone', 2, 800.00, 1600.00)
6102+
)`,
60776103
)
6078-
require.NoError(t, err)
6079-
require.Len(t, rows, 4)
60806104
}
60816105

60826106
func TestReOpening(t *testing.T) {
@@ -9527,3 +9551,24 @@ func TestFunctions(t *testing.T) {
95279551
require.Equal(t, "OBJECT", rows[0].ValuesByPosition[0].RawValue().(string))
95289552
})
95299553
}
9554+
9555+
func assertQueryShouldProduceResults(t *testing.T, e *Engine, query, resultQuery string) {
9556+
queryReader, err := e.Query(context.Background(), nil, query, nil)
9557+
require.NoError(t, err)
9558+
defer queryReader.Close()
9559+
9560+
resultReader, err := e.Query(context.Background(), nil, resultQuery, nil)
9561+
require.NoError(t, err)
9562+
defer resultReader.Close()
9563+
9564+
for {
9565+
row, err := queryReader.Read(context.Background())
9566+
row1, err1 := resultReader.Read(context.Background())
9567+
require.Equal(t, err, err1)
9568+
9569+
if errors.Is(err, ErrNoMoreRows) {
9570+
break
9571+
}
9572+
require.Equal(t, row1.ValuesByPosition, row.ValuesByPosition)
9573+
}
9574+
}

embedded/sql/joint_row_reader.go

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,6 @@ func newJointRowReader(rowReader RowReader, joins []*JoinSpec) (*jointRowReader,
3838
return nil, ErrIllegalArguments
3939
}
4040

41-
// Sanity check: Ensure that no RIGHT JOINs are specified,
42-
// as we assume all RIGHT JOINs to be translated into equivalent LEFT JOINs.
4341
for _, jspec := range joins {
4442
if jspec.joinType == RightJoin {
4543
return nil, ErrUnsupportedJoinType

embedded/sql/stmt.go

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3409,44 +3409,9 @@ func (stmt *SelectStmt) Resolve(ctx context.Context, tx *SQLTx, params map[strin
34093409
rowReader = newLimitRowReader(rowReader, limit)
34103410
}
34113411
}
3412-
34133412
return rowReader, nil
34143413
}
34153414

3416-
// removeRightJoin converts all right joins in the SelectStmt to left joins by swapping the involved data sources.
3417-
func (stmt *SelectStmt) removeRightJoin() {
3418-
if len(stmt.joins) == 0 {
3419-
return
3420-
}
3421-
3422-
newJoins := make([]*JoinSpec, len(stmt.joins)+1)
3423-
3424-
start := 0
3425-
end := len(newJoins) - 1
3426-
3427-
for i := len(stmt.joins) - 1; i > 0; i-- {
3428-
jspec := stmt.joins[len(stmt.joins)-1-i]
3429-
3430-
if jspec.joinType == RightJoin {
3431-
newJoins[start] = jspec
3432-
newJoins[start].joinType = LeftJoin
3433-
start++
3434-
} else {
3435-
newJoins[end] = jspec
3436-
end--
3437-
}
3438-
}
3439-
3440-
newJoins[start] = &JoinSpec{ds: stmt.ds}
3441-
if start == 0 {
3442-
stmt.joins = newJoins[1:]
3443-
return
3444-
}
3445-
3446-
for i := start; i > 0; i-- {
3447-
}
3448-
}
3449-
34503415
func (stmt *SelectStmt) rearrangeOrdExps(groupByCols, orderByExps []*OrdExp) ([]*OrdExp, []*OrdExp) {
34513416
if len(groupByCols) > 0 && len(orderByExps) > 0 && !ordExpsHaveAggregations(orderByExps) {
34523417
if ordExpsHasPrefix(orderByExps, groupByCols, stmt.Alias()) {

0 commit comments

Comments
 (0)