1- import { defaultTypeExprMapping , fmap , normalizeIndent } from "@ts-safeql/shared" ;
1+ import { fmap , normalizeIndent } from "@ts-safeql/shared" ;
22import * as LibPgQueryAST from "@ts-safeql/sql-ast" ;
33import {
44 isColumnStarRef ,
@@ -16,6 +16,7 @@ type ASTDescriptionOptions = {
1616 parsed : LibPgQueryAST . ParseResult ;
1717 relations : FlattenedRelationWithJoins [ ] ;
1818 typesMap : Map < string , { override : boolean ; value : string } > ;
19+ typeExprMap : Map < string , Map < string , Map < string , string > > > ;
1920 overridenColumnTypesMap : Map < string , Map < string , string > > ;
2021 nonNullableColumns : Set < string > ;
2122 pgColsBySchemaAndTableName : Map < string , Map < string , PgColRow [ ] > > ;
@@ -274,61 +275,119 @@ function getDescribedAExpr({
274275
275276 if ( column === undefined ) return null ;
276277
277- if ( column . type . kind === "array" ) {
278- return { value : "array" , nullable : false } ;
279- }
278+ const getFromType = (
279+ type : ASTDescribedColumnType ,
280+ ) : { value : string ; array : boolean ; nullable : boolean } | null => {
281+ switch ( true ) {
282+ case type . kind === "type" :
283+ return { value : type . base ?? type . type , array : false , nullable : false } ;
280284
281- if ( column . type . kind === "type" ) {
282- return { value : column . type . base ?? column . type . type , nullable : false } ;
283- }
285+ case type . kind === "literal" && type . base . kind === "type" :
286+ return { value : type . base . type , array : false , nullable : false } ;
284287
285- if ( column . type . kind === "literal" && column . type . base . kind === "type" ) {
286- return { value : column . type . base . type , nullable : false } ;
287- }
288+ case type . kind === "union" && type . value . every ( ( x ) => x . kind === "literal" ) : {
289+ const resolved = getFromType ( type . value [ 0 ] . base ) ;
290+
291+ if ( resolved === null ) return null ;
292+
293+ return { value : resolved . value , nullable : false , array : false } ;
294+ }
295+
296+ case type . kind === "union" && isTuple ( type . value ) : {
297+ let nullable = false ;
298+ let value : string | undefined = undefined ;
299+
300+ for ( const valueType of type . value ) {
301+ if ( valueType . kind !== "type" ) return null ;
302+ if ( valueType . value === "null" ) nullable = true ;
303+ if ( valueType . value !== "null" ) value = valueType . type ;
304+ }
305+
306+ if ( value === undefined ) return null ;
288307
289- if ( column . type . kind === "union" && isTuple ( column . type . value ) ) {
290- let nullable = false ;
291- let value : string | undefined = undefined ;
308+ return { value, nullable, array : false } ;
309+ }
292310
293- for ( const type of column . type . value ) {
294- if ( type . kind !== "type" ) return null ;
295- if ( type . value === "null" ) nullable = true ;
296- if ( type . value !== "null" ) value = type . type ;
311+ default :
312+ return null ;
297313 }
314+ } ;
315+
316+ if ( column . type . kind === "array" ) {
317+ const resolved = getFromType ( column . type . value ) ;
298318
299- if ( value === undefined ) return null ;
319+ if ( ! resolved ) return null ;
300320
301- return { value, nullable } ;
321+ return { value : resolved . value , nullable : resolved . nullable , array : true } ;
302322 }
303323
304- return null ;
324+ return getFromType ( column . type ) ;
305325 } ;
306326
307327 const lnode = getResolvedNullableValueOrNull ( node . lexpr ) ;
308328 const rnode = getResolvedNullableValueOrNull ( node . rexpr ) ;
329+ const operator = concatStringNodes ( node . name ) ;
309330
310331 if ( lnode === null || rnode === null ) {
311332 return [ ] ;
312333 }
313334
314- const operator = concatStringNodes ( node . name ) ;
315- const resolved : string | undefined =
316- defaultTypeExprMapping [ `${ lnode . value } ${ operator } ${ rnode . value } ` ] ;
335+ const downcast = ( ) => {
336+ const left = lnode . array ? `_${ lnode . value } ` : lnode . value ;
337+ const right = rnode . array ? `_${ rnode . value } ` : rnode . value ;
338+
339+ const overrides : Record < string , [ string , string , string ] > = {
340+ "int4 ^ int4" : [ "float8" , "^" , "float8" ] ,
341+ } ;
342+
343+ if ( overrides [ `${ left } ${ operator } ${ right } ` ] ) {
344+ return overrides [ `${ left } ${ operator } ${ right } ` ] ;
345+ }
346+
347+ const adjust = ( value : string ) => ( value === "varchar" ? "text" : value ) ;
348+
349+ return [ adjust ( left ) , operator , adjust ( right ) ] ;
350+ } ;
317351
318- if ( resolved === undefined ) {
352+ const getType = ( ) : ASTDescribedColumnType | undefined => {
353+ const nullable = ! context . nonNullableColumns . has ( name ) && ( lnode . nullable || rnode . nullable ) ;
354+ const [ dleft , doperator , dright ] = downcast ( ) ;
355+
356+ const type =
357+ context . typeExprMap . get ( dleft ) ?. get ( doperator ) ?. get ( dright ) ??
358+ context . typeExprMap . get ( "anycompatiblearray" ) ?. get ( operator ) ?. get ( "anycompatiblearray" ) ??
359+ context . typeExprMap . get ( "anyarray" ) ?. get ( operator ) ?. get ( "anyarray" ) ??
360+ context . typeExprMap . get ( lnode . value ) ?. get ( operator ) ?. values ( ) . next ( ) . value ;
361+
362+ if ( type === undefined ) {
363+ return ;
364+ }
365+
366+ if ( type === "anycompatiblearray" ) {
367+ return {
368+ kind : "array" ,
369+ value : resolveType ( {
370+ context,
371+ nullable,
372+ type : context . toTypeScriptType ( { name : lnode . value } ) ,
373+ } ) ,
374+ } ;
375+ }
376+
377+ return resolveType ( {
378+ context,
379+ nullable,
380+ type : context . toTypeScriptType ( { name : type } ) ,
381+ } ) ;
382+ } ;
383+
384+ const type = getType ( ) ;
385+
386+ if ( type === undefined ) {
319387 return [ ] ;
320388 }
321389
322- return [
323- {
324- name : name ,
325- type : resolveType ( {
326- context : context ,
327- nullable : ! context . nonNullableColumns . has ( name ) && ( lnode . nullable || rnode . nullable ) ,
328- type : context . toTypeScriptType ( { name : resolved } ) ,
329- } ) ,
330- } ,
331- ] ;
390+ return [ { name, type } ] ;
332391}
333392
334393function getDescribedNullTest ( {
@@ -815,14 +874,14 @@ function getColumnRefOrigins({
815874 // lookup in cte
816875 context . select . withClause ?. ctes
817876 . find ( ( cte ) => cte . CommonTableExpr ?. ctename === source )
818- ?. CommonTableExpr ?. ctequery ?. SelectStmt ?. targetList ?. map ( ( x ) => x . ResTarget )
819- . find ( ( x ) => x ?. name === column ) ?. val ??
877+ ?. CommonTableExpr ?. ctequery ?. SelectStmt ?. targetList ?. find (
878+ ( x ) => x . ResTarget ?. name === column ,
879+ ) ??
820880 // lookup in subselect
821881 context . select . fromClause
822882 ?. map ( ( from ) => from . RangeSubselect )
823883 . find ( ( subselect ) => subselect ?. alias ?. aliasname === source )
824- ?. subquery ?. SelectStmt ?. targetList ?. map ( ( x ) => x . ResTarget )
825- . find ( ( x ) => x ?. name === column ) ?. val ;
884+ ?. subquery ?. SelectStmt ?. targetList ?. find ( ( x ) => x . ResTarget ?. name === column ) ;
826885
827886 if ( ! origin ) return undefined ;
828887
0 commit comments