175175 '''
176176
177177class Generator (FilePrinter ):
178- def __init__ (self , args : argparse .Namespace ):
178+ def __init__ (self , args : argparse .Namespace , client = None ):
179179 self ._default_module = "default"
180180 self ._targets = args .target
181181 self ._async = False
182- try :
183- self ._project_dir = pathlib .Path (find_gel_project_dir ())
184- except gel .ClientConnectionError :
185- print (
186- "Cannot find gel.toml: "
187- "codegen must be run under an EdgeDB project dir"
188- )
189- sys .exit (2 )
190- print_msg (f"Found EdgeDB project: { C .BOLD } { self ._project_dir } { C .ENDC } " )
191- self ._client = gel .create_client (** _get_conn_args (args ))
182+ if client is not None :
183+ self ._client = client
184+ else :
185+ self ._client = gel .create_client (** _get_conn_args (args ))
192186 self ._describe_results = []
193187
194188 self ._cache = {}
@@ -197,11 +191,13 @@ def __init__(self, args: argparse.Namespace):
197191 self ._defs = {}
198192 self ._names = set ()
199193
200- self ._basemodule = 'models'
201- self ._outdir = pathlib .Path ('models' )
194+ self ._basemodule = args . mod
195+ self ._outdir = pathlib .Path (args . out )
202196 self ._modules = {}
203197 self ._types = {}
204198
199+ self .init_dir (self ._outdir )
200+
205201 super ().__init__ ()
206202
207203 def run (self ):
@@ -285,74 +281,117 @@ def write_types(self, maps):
285281 scalar_types = maps ['scalar_types' ]
286282
287283 if object_types :
288- self .write (f'from typing import Optional, Any, Annotated' )
284+ self .write (f'import pydantic' )
285+ self .write (f'import typing as pt' )
286+ self .write (f'import uuid' )
289287 self .write (f'from gel.compatibility import pydmodels as gm' )
290288
291289 objects = sorted (
292290 object_types .values (), key = lambda x : x .name
293291 )
294292 for obj in objects :
293+ self .render_type (obj , variant = 'Base' )
294+ self .render_type (obj , variant = 'Update' )
295295 self .render_type (obj )
296296
297- def render_type (self , objtype ):
297+ def render_type (self , objtype , * , variant = None ):
298298 mod , name = get_mod_and_name (objtype .name )
299+ is_empty = True
299300
300301 self .write ()
301302 self .write ()
302- self .write (f'class { name } (gm.BaseGelModel):' )
303- self .indent ()
304- self .write (f'__gel_name__ = { objtype .name !r} ' )
305-
306- if len (objtype .properties ) > 0 :
303+ match variant :
304+ case 'Base' :
305+ self .write (f'class _{ variant } { name } (gm.BaseGelModel):' )
306+ self .indent ()
307+ self .write (f'__gel_name__ = { objtype .name !r} ' )
308+ case 'Update' :
309+ self .write (f'class _{ variant } { name } (gm.UpdateGelModel):' )
310+ self .indent ()
311+ self .write (f'__gel_name__ = { objtype .name !r} ' )
312+ self .write (
313+ f"id: pt.Annotated[uuid.UUID, gm.GelType('std::uuid'), "
314+ f"gm.Exclusive]"
315+ )
316+ case _:
317+ self .write (f'class { name } (_Base{ name } ):' )
318+ self .indent ()
319+
320+ if variant and len (objtype .properties ) > 0 :
321+ is_empty = False
307322 self .write ()
308323 self .write ('# Properties:' )
309324 for prop in objtype .properties :
310- self .render_prop (prop , mod )
325+ self .render_prop (prop , mod , variant = variant )
311326
312- if len (objtype .links ) > 0 :
313- self .write ()
314- self .write ('# Properties:' )
327+ if variant != 'Base' and len (objtype .links ) > 0 :
328+ if variant or not is_empty :
329+ self .write ()
330+ is_empty = False
331+ self .write ('# Links:' )
315332 for link in objtype .links :
316- self .render_link (link , mod )
333+ self .render_link (link , mod , variant = variant )
334+
335+ if not variant :
336+ if not is_empty :
337+ self .write ()
338+ self .write ('# Class variants:' )
339+ self .write (f'base: pt.ClassVar = _Base{ name } ' )
340+ self .write (f'update: pt.ClassVar = _Update{ name } ' )
317341
318342 self .dedent ()
319343
320- def render_prop (self , prop , curmod ):
344+ def render_prop (self , prop , curmod , * , variant = None ):
321345 pytype = TYPE_MAPPING .get (prop .target .name )
346+ annotated = [f'gm.GelType({ prop .target .name !r} )' ]
322347 defval = ''
323348 if not pytype :
324349 # skip
325350 return
326351
327- # FIXME: need to also handle multi
352+ if str (prop .cardinality ) == 'Many' :
353+ annotated .append ('gm.Multi' )
354+ pytype = f'pt.List[{ pytype } ]'
355+ defval = ' = []'
328356
329- if not prop .required :
330- pytype = f'Optional[{ pytype } ]'
357+ if variant == 'Update' or not prop .required :
358+ pytype = f'pt. Optional[{ pytype } ]'
331359 # A value does not need to be supplied
332360 defval = ' = None'
333361
334362 if prop .exclusive :
335- pytype = f'Annotated[{ pytype } , gm.Exclusive]'
363+ annotated .append ('gm.Exclusive' )
364+
365+ anno = ', ' .join ([pytype ] + annotated )
366+ pytype = f'pt.Annotated[{ anno } ]'
336367
337368 self .write (
338369 f'{ prop .name } : { pytype } { defval } '
339370 )
340371
341- def render_link (self , link , curmod ):
372+ def render_link (self , link , curmod , * , variant = None ):
342373 mod , name = get_mod_and_name (link .target .name )
374+ annotated = [f'gm.GelType({ link .target .name !r} )' , 'gm.Link' ]
375+ defval = ''
343376 if curmod == mod :
344377 pytype = name
345378 else :
346379 pytype = link .target .name .replace ('::' , '.' )
380+ pytype = repr (pytype )
347381
348- # FIXME: need to also handle multi
382+ if str (link .cardinality ) == 'Many' :
383+ annotated .append ('gm.Multi' )
384+ pytype = f'pt.List[{ pytype } ]'
385+ defval = ' = []'
349386
350- if link .required :
351- self .write (
352- f'{ link .name } : { pytype !r} '
353- )
354- else :
387+ if variant == 'Update' or not link .required :
388+ pytype = f'pt.Optional[{ pytype } ]'
355389 # A value does not need to be supplied
356- self .write (
357- f'{ link .name } : Optional[{ pytype !r} ] = None'
358- )
390+ defval = ' = None'
391+
392+ anno = ', ' .join ([pytype ] + annotated )
393+ pytype = f'pt.Annotated[{ anno } ]'
394+
395+ self .write (
396+ f'{ link .name } : { pytype } { defval } '
397+ )
0 commit comments