@@ -47,6 +47,8 @@ import (
4747var (
4848 //go:embed templates/*.template
4949 templates embed.FS
50+
51+ packageNames sync.Map
5052)
5153
5254func usage () {
@@ -89,14 +91,23 @@ func main() {
8991 rawmessage : * rawmessage ,
9092 }
9193
94+ c , err := gelint .NewPool ("" , gelcfg.Options {})
95+ if err != nil {
96+ log .Fatalf ("creating client: %s" , err ) // nolint:gocritic
97+ }
98+
9299 timer := time .AfterFunc (200 * time .Millisecond , func () {
93- log .Println ("connecting to Gel" )
100+ log .Println ("connecting to Gel ... " )
94101 })
95- defer timer .Stop ()
96102
97- c , err := gelint .NewPool ("" , gelcfg.Options {})
103+ ctx := context .Background ()
104+ err = c .EnsureConnected (ctx )
98105 if err != nil {
99- log .Fatalf ("creating client: %s" , err ) // nolint:gocritic
106+ log .Fatalf ("connecting to Gel: %v" , err )
107+ }
108+
109+ if ! timer .Stop () {
110+ log .Println ("connected" )
100111 }
101112
102113 fileQueue := queueFilesInBackground ()
@@ -108,7 +119,6 @@ func main() {
108119 log .Fatal (err )
109120 }
110121
111- ctx := context .Background ()
112122 var wg sync.WaitGroup
113123 for queryFile := range fileQueue {
114124 wg .Add (1 )
@@ -231,6 +241,25 @@ func queueFilesInBackground() chan string {
231241 }
232242
233243 if ! d .IsDir () && strings .HasSuffix (f , ".edgeql" ) {
244+ dirname , err := getDirName (f )
245+ if err != nil {
246+ return err
247+ }
248+
249+ // Cache package names before any go files are written.
250+ // getPackageName reads the other go files in the directory
251+ // to find the package name. If we are writing go files in
252+ // this directory at the same time there is a race
253+ // condition where an empty or partially written file will
254+ // cause getPackageName to return an error because the file
255+ // is malformed.
256+ if _ , ok := packageNames .Load (dirname ); ! ok {
257+ packageName , err := getPackageName (dirname )
258+ if err != nil {
259+ return err
260+ }
261+ packageNames .Store (dirname , packageName )
262+ }
234263 queue <- f
235264 }
236265
@@ -252,9 +281,14 @@ func writeGoFile(
252281 outFile string ,
253282 queries []* Query ,
254283) error {
255- packageName , err := getPackageName (outFile )
284+ dirname , err := getDirName (outFile )
256285 if err != nil {
257- log .Fatal (err )
286+ return err
287+ }
288+
289+ packageName , ok := packageNames .Load (dirname )
290+ if ! ok {
291+ return fmt .Errorf ("no package name found for %q" , outFile )
258292 }
259293
260294 var imports []string
@@ -287,16 +321,20 @@ func writeGoFile(
287321 return nil
288322}
289323
290- // getPackageName looks up the package name from the first adjacent .go file it
291- // finds. If there are no adjacent .go files it uses the lower case version of
292- // the directory name as the package name.
293- func getPackageName (outFile string ) (string , error ) {
294- outFile , err := filepath .Abs (outFile )
324+ func getDirName (file string ) (string , error ) {
325+ file , err := filepath .Abs (file )
295326 if err != nil {
296327 return "" , err
297328 }
298329
299- dirname := filepath .Dir (outFile )
330+ dirname := filepath .Dir (file )
331+ return dirname , nil
332+ }
333+
334+ // getPackageName looks up the package name from the first adjacent .go file it
335+ // finds. If there are no adjacent .go files it uses the lower case version of
336+ // the directory name as the package name.
337+ func getPackageName (dirname string ) (string , error ) {
300338 entries , err := os .ReadDir (dirname )
301339 if err != nil {
302340 return "" , err
0 commit comments