Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions edb/pgsql/resolver/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,8 @@ def resolve_LockingClause(
'jsonb_object_agg_unique',
'json_object_agg_unique_strict',
'jsonb_object_agg_unique_strict',
'date_trunc',
'date_part',
}


Expand Down
15 changes: 9 additions & 6 deletions edb/server/protocol/pg_ext.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1932,12 +1932,15 @@ cdef bytes remap_arguments(
if arg_offset_external < offset:
buf.write_bytes(data[arg_offset_external:offset])

# write non-external args
extracted_consts = list(source.variables().values())
for (e, param) in enumerate(params[param_count_external:]):
if isinstance(param, dbstate.SQLParamExtractedConst):
buf.write_len_prefixed_bytes(extracted_consts[e])
elif isinstance(param, dbstate.SQLParamGlobal):
# Write extracted consts from extra_blobs. The blob already contains
# all len-prefixed values in param order (matching args_ser.pyx pattern).
extra_blobs = source.extra_blobs()
if extra_blobs:
buf.write_bytes(extra_blobs[0])

# Write globals
for param in params[param_count_external:]:
if isinstance(param, dbstate.SQLParamGlobal):
name = param.global_name
if param.is_permission:
buf.write_int32(1)
Expand Down
25 changes: 25 additions & 0 deletions tests/test_sql_dml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2545,3 +2545,28 @@ async def test_sql_dml_03(self):
'''SELECT COUNT(*) FROM "Child.tags"'''
)
self.assertEqual(res, [[1]])

async def test_sql_dml_date_trunc_group_by(self):
    """Regression test: DATE_TRUNC must work inside GROUP BY over SQL adapter.

    Exercises the resolver's handling of ``date_trunc`` (added to the
    allowed-function set in ``edb/pgsql/resolver/expr.py``) when the same
    call appears in both the select list and the GROUP BY clause.

    NOTE(review): assumes the test schema declares a "Post" type with
    ``title`` and ``created_at`` properties — defined elsewhere; confirm
    against the schema used by this test suite.
    """
    # Seed three rows spanning two months: two in January 2024 and one
    # in February 2024, so the grouped counts are distinguishable.
    await self.scon.execute(
        '''
        INSERT INTO "Post" (title, created_at) VALUES
        ('Post 1', '2024-01-15 10:00:00+00'),
        ('Post 2', '2024-01-20 15:30:00+00'),
        ('Post 3', '2024-02-10 09:00:00+00')
        '''
    )

    # Group by month bucket; ORDER BY guarantees January precedes
    # February so the positional assertions below are deterministic.
    res = await self.squery_values(
        '''
        SELECT
            DATE_TRUNC('month', p.created_at) as log_date,
            COUNT(*) as log_count
        FROM "Post" p
        GROUP BY DATE_TRUNC('month', p.created_at)
        ORDER BY log_date
        '''
    )
    # Two distinct months -> exactly two result rows.
    self.assertEqual(len(res), 2)
    self.assertEqual(res[0][1], 2)  # January count
    self.assertEqual(res[1][1], 1)  # February count
Loading