Skip to content

Commit 6cc4f0b

Browse files
committed
Fix for incorrect subquery column type.
1 parent 38725aa commit 6cc4f0b

3 files changed

Lines changed: 171 additions & 8 deletions

File tree

internal/jet/alias.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ func (a *alias) fromImpl(subQuery SelectTable) Projection {
2727

2828
// This function is used to create dummy columns when exporting sub-query columns using subQuery.AllColumns()
2929
// In most case we don't care about type of the column, except when sub-query columns are used as SELECT_JSON projection.
30-
// We need to know type to encode value for json unmarshal. At the moment only bool, time and blob columns are of interest,
31-
// so we don't have to support every column type.
30+
// We need to know type to encode value for json unmarshal.
3231
func newDummyColumnForExpression(exp Expression, name string) ColumnExpression {
3332

3433
switch exp.(type) {
@@ -54,6 +53,41 @@ func newDummyColumnForExpression(exp Expression, name string) ColumnExpression {
5453
return IntervalColumn(name)
5554
case StringExpression:
5655
return StringColumn(name)
56+
57+
case Array[BoolExpression]:
58+
return ArrayColumn[BoolExpression](name)
59+
case Array[IntegerExpression]:
60+
return ArrayColumn[IntegerExpression](name)
61+
case Array[FloatExpression]:
62+
return ArrayColumn[FloatExpression](name)
63+
case Array[BlobExpression]:
64+
return ArrayColumn[BlobExpression](name)
65+
case Array[DateExpression]:
66+
return ArrayColumn[DateExpression](name)
67+
case Array[TimeExpression]:
68+
return ArrayColumn[TimeExpression](name)
69+
case Array[TimezExpression]:
70+
return ArrayColumn[TimezExpression](name)
71+
case Array[TimestampExpression]:
72+
return ArrayColumn[TimestampExpression](name)
73+
case Array[TimestampzExpression]:
74+
return ArrayColumn[TimestampzExpression](name)
75+
case Array[IntervalExpression]:
76+
return ArrayColumn[IntervalExpression](name)
77+
case Array[StringExpression]:
78+
return ArrayColumn[StringExpression](name)
79+
80+
case Range[Int4Expression], Range[Int8Expression]:
81+
return RangeColumn[IntegerExpression](name)
82+
case Range[NumericExpression]:
83+
return RangeColumn[NumericExpression](name)
84+
case Range[DateExpression]:
85+
return RangeColumn[DateExpression](name)
86+
case Range[TimestampExpression]:
87+
return RangeColumn[TimestampExpression](name)
88+
case Range[TimestampzExpression]:
89+
return RangeColumn[TimestampzExpression](name)
90+
5791
}
5892

5993
return StringColumn(name)

internal/jet/array_expression.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,10 @@ type arrayExpressionWrapper[E Expression] struct {
8080
}
8181

8282
func newArrayExpressionWrap[E Expression](expression Expression) Array[E] {
83-
arrayExpressionWrapper := arrayExpressionWrapper[E]{Expression: expression}
84-
arrayExpressionWrapper.arrayInterfaceImpl.parent = &arrayExpressionWrapper
85-
return &arrayExpressionWrapper
83+
arrayExpressionWrapper := &arrayExpressionWrapper[E]{Expression: expression}
84+
arrayExpressionWrapper.arrayInterfaceImpl.parent = arrayExpressionWrapper
85+
expression.setRoot(arrayExpressionWrapper)
86+
return arrayExpressionWrapper
8687
}
8788

8889
// ArrayExp is array expression wrapper around arbitrary expression.

tests/postgres/alltypes_test.go

Lines changed: 131 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ package postgres
22

33
import (
44
"encoding/base64"
5+
"fmt"
56
"github.com/go-jet/jet/v2/internal/utils/ptr"
6-
"github.com/stretchr/testify/assert"
7-
"math"
8-
97
"github.com/go-jet/jet/v2/qrm"
108
"github.com/lib/pq"
9+
"github.com/stretchr/testify/assert"
10+
"math"
1111
"testing"
1212
"time"
1313

@@ -1728,6 +1728,134 @@ SELECT ROW($1::integer, $2::real, $3::text) AS "row",
17281728
})
17291729
}
17301730

1731+
func TestSubQueryAllExpTypes(t *testing.T) {
1732+
1733+
subquery := SELECT(
1734+
Bool(true).AS("bool"),
1735+
Int(11).AS("int"),
1736+
Text("doe").AS("text"),
1737+
Date(2000, 2, 2).AS("date"),
1738+
Time(11, 20, 40).AS("time"),
1739+
Timez(11, 20, 40, 200, "UTC").AS("timez"),
1740+
Timestamp(2030, 3, 4, 11, 20, 40).AS("timestamp"),
1741+
Timestampz(2023, 1, 2, 11, 20, 40, 200, "UTC").AS("timestampz"),
1742+
INTERVAL(100, HOUR).AS("interval"),
1743+
Bytea("bytes").AS("bytea"),
1744+
1745+
ARRAY(Bool(true)).AS("bool_arr"),
1746+
ARRAY(Int(11)).AS("int_arr"),
1747+
ARRAY(Text("doe")).AS("text_arr"),
1748+
ARRAY(Date(2000, 2, 2)).AS("date_arr"),
1749+
ARRAY(Time(11, 20, 40)).AS("time_arr"),
1750+
ARRAY(Timez(11, 20, 40, 200, "UTC")).AS("timez_arr"),
1751+
ARRAY(Timestamp(2030, 3, 4, 11, 20, 40)).AS("timestamp_arr"),
1752+
ARRAY(Timestampz(2023, 1, 2, 11, 20, 40, 200, "UTC")).AS("timestampz_arr"),
1753+
ARRAY(INTERVAL(100, HOUR)).AS("interval_arr"),
1754+
ARRAY(Bytea("bytes")).AS("bytea_arr"),
1755+
1756+
INT4_RANGE(Int(1), Int(200)).AS("int4_range"),
1757+
DATE_RANGE(Date(2000, 2, 2), Date(2000, 3, 3)).AS("date_range"),
1758+
NUM_RANGE(Float(33.22), Float(22.1)).AS("num_range"),
1759+
TS_RANGE(LOCALTIMESTAMP(), LOCALTIMESTAMP()).AS("ts_range"),
1760+
TSTZ_RANGE(NOW(), NOW()).AS("tstz_range"),
1761+
).AsTable("sub")
1762+
1763+
var result = "\n"
1764+
for _, projection := range subquery.AllColumns() {
1765+
result += fmt.Sprintf("Column type: %T\n", projection)
1766+
}
1767+
1768+
require.Equal(t, result, `
1769+
Column type: *jet.boolColumnImpl
1770+
Column type: *jet.integerColumnImpl
1771+
Column type: *jet.stringColumnImpl
1772+
Column type: *jet.dateColumnImpl
1773+
Column type: *jet.timeColumnImpl
1774+
Column type: *jet.timezColumnImpl
1775+
Column type: *jet.timestampColumnImpl
1776+
Column type: *jet.timestampzColumnImpl
1777+
Column type: *jet.intervalColumnImpl
1778+
Column type: *jet.blobColumnImpl
1779+
Column type: *jet.arrayColumnImpl[github.com/go-jet/jet/v2/internal/jet.BoolExpression]
1780+
Column type: *jet.arrayColumnImpl[github.com/go-jet/jet/v2/internal/jet.IntegerExpression]
1781+
Column type: *jet.arrayColumnImpl[github.com/go-jet/jet/v2/internal/jet.StringExpression]
1782+
Column type: *jet.arrayColumnImpl[github.com/go-jet/jet/v2/internal/jet.DateExpression]
1783+
Column type: *jet.arrayColumnImpl[github.com/go-jet/jet/v2/internal/jet.TimeExpression]
1784+
Column type: *jet.arrayColumnImpl[github.com/go-jet/jet/v2/internal/jet.TimezExpression]
1785+
Column type: *jet.arrayColumnImpl[github.com/go-jet/jet/v2/internal/jet.TimestampExpression]
1786+
Column type: *jet.arrayColumnImpl[github.com/go-jet/jet/v2/internal/jet.TimestampzExpression]
1787+
Column type: *jet.arrayColumnImpl[github.com/go-jet/jet/v2/internal/jet.IntervalExpression]
1788+
Column type: *jet.arrayColumnImpl[github.com/go-jet/jet/v2/internal/jet.BlobExpression]
1789+
Column type: *jet.rangeColumnImpl[github.com/go-jet/jet/v2/internal/jet.IntegerExpression]
1790+
Column type: *jet.rangeColumnImpl[github.com/go-jet/jet/v2/internal/jet.DateExpression]
1791+
Column type: *jet.rangeColumnImpl[github.com/go-jet/jet/v2/internal/jet.NumericExpression]
1792+
Column type: *jet.rangeColumnImpl[github.com/go-jet/jet/v2/internal/jet.TimestampExpression]
1793+
Column type: *jet.rangeColumnImpl[github.com/go-jet/jet/v2/internal/jet.TimestampzExpression]
1794+
`)
1795+
1796+
stmt := SELECT(
1797+
subquery.AllColumns(),
1798+
).FROM(subquery)
1799+
1800+
testutils.AssertStatementSql(t, stmt, `
1801+
SELECT sub.bool AS "bool",
1802+
sub.int AS "int",
1803+
sub.text AS "text",
1804+
sub.date AS "date",
1805+
sub.time AS "time",
1806+
sub.timez AS "timez",
1807+
sub.timestamp AS "timestamp",
1808+
sub.timestampz AS "timestampz",
1809+
sub.interval AS "interval",
1810+
sub.bytea AS "bytea",
1811+
sub.bool_arr AS "bool_arr",
1812+
sub.int_arr AS "int_arr",
1813+
sub.text_arr AS "text_arr",
1814+
sub.date_arr AS "date_arr",
1815+
sub.time_arr AS "time_arr",
1816+
sub.timez_arr AS "timez_arr",
1817+
sub.timestamp_arr AS "timestamp_arr",
1818+
sub.timestampz_arr AS "timestampz_arr",
1819+
sub.interval_arr AS "interval_arr",
1820+
sub.bytea_arr AS "bytea_arr",
1821+
sub.int4_range AS "int4_range",
1822+
sub.date_range AS "date_range",
1823+
sub.num_range AS "num_range",
1824+
sub.ts_range AS "ts_range",
1825+
sub.tstz_range AS "tstz_range"
1826+
FROM (
1827+
SELECT $1::boolean AS "bool",
1828+
$2 AS "int",
1829+
$3::text AS "text",
1830+
$4::date AS "date",
1831+
$5::time without time zone AS "time",
1832+
$6::time with time zone AS "timez",
1833+
$7::timestamp without time zone AS "timestamp",
1834+
$8::timestamp with time zone AS "timestampz",
1835+
INTERVAL '100 HOUR' AS "interval",
1836+
$9::bytea AS "bytea",
1837+
ARRAY[$10::boolean] AS "bool_arr",
1838+
ARRAY[$11] AS "int_arr",
1839+
ARRAY[$12::text] AS "text_arr",
1840+
ARRAY[$13::date] AS "date_arr",
1841+
ARRAY[$14::time without time zone] AS "time_arr",
1842+
ARRAY[$15::time with time zone] AS "timez_arr",
1843+
ARRAY[$16::timestamp without time zone] AS "timestamp_arr",
1844+
ARRAY[$17::timestamp with time zone] AS "timestampz_arr",
1845+
ARRAY[INTERVAL '100 HOUR'] AS "interval_arr",
1846+
ARRAY[$18::bytea] AS "bytea_arr",
1847+
int4range($19, $20) AS "int4_range",
1848+
daterange($21::date, $22::date) AS "date_range",
1849+
numrange($23, $24) AS "num_range",
1850+
tsrange(LOCALTIMESTAMP, LOCALTIMESTAMP) AS "ts_range",
1851+
tstzrange(NOW(), NOW()) AS "tstz_range"
1852+
) AS sub;
1853+
`)
1854+
1855+
_, err := stmt.Exec(db)
1856+
require.NoError(t, err)
1857+
}
1858+
17311859
func TestAllTypesSubQueryFrom(t *testing.T) {
17321860
subQuery := SELECT(
17331861
AllTypes.Boolean,

0 commit comments

Comments
 (0)