Skip to content

Commit 4a2f392

Browse files
authored
Fix edgeql-go file write race condition (#395)
To find the package name, edgeql-go reads the other go files in the directory. This was happening at the same time that generated files in that directory were being written. If a file being written in one go routine was read in another it would not parse because it was incomplete. This caused edgeql-go to crash. The package name is now determined only once for a directory before processing any of the queries in that directory.
1 parent 6497b85 commit 4a2f392

File tree

1 file changed

+51
-13
lines changed

1 file changed

+51
-13
lines changed

cmd/edgeql-go/main.go

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ import (
4747
var (
4848
//go:embed templates/*.template
4949
templates embed.FS
50+
51+
packageNames sync.Map
5052
)
5153

5254
func usage() {
@@ -89,14 +91,23 @@ func main() {
8991
rawmessage: *rawmessage,
9092
}
9193

94+
c, err := gelint.NewPool("", gelcfg.Options{})
95+
if err != nil {
96+
log.Fatalf("creating client: %s", err) // nolint:gocritic
97+
}
98+
9299
timer := time.AfterFunc(200*time.Millisecond, func() {
93-
log.Println("connecting to Gel")
100+
log.Println("connecting to Gel ...")
94101
})
95-
defer timer.Stop()
96102

97-
c, err := gelint.NewPool("", gelcfg.Options{})
103+
ctx := context.Background()
104+
err = c.EnsureConnected(ctx)
98105
if err != nil {
99-
log.Fatalf("creating client: %s", err) // nolint:gocritic
106+
log.Fatalf("connecting to Gel: %v", err)
107+
}
108+
109+
if !timer.Stop() {
110+
log.Println("connected")
100111
}
101112

102113
fileQueue := queueFilesInBackground()
@@ -108,7 +119,6 @@ func main() {
108119
log.Fatal(err)
109120
}
110121

111-
ctx := context.Background()
112122
var wg sync.WaitGroup
113123
for queryFile := range fileQueue {
114124
wg.Add(1)
@@ -231,6 +241,25 @@ func queueFilesInBackground() chan string {
231241
}
232242

233243
if !d.IsDir() && strings.HasSuffix(f, ".edgeql") {
244+
dirname, err := getDirName(f)
245+
if err != nil {
246+
return err
247+
}
248+
249+
// Cache package names before any go files are written.
250+
// getPackageName reads the other go files in the directory
251+
// to find the package name. If we are writing go files in
252+
// this directory at the same time there is a race
253+
// condition where an empty or partially written file will
254+
// cause getPackageName to return an error because the file
255+
// is malformed.
256+
if _, ok := packageNames.Load(dirname); !ok {
257+
packageName, err := getPackageName(dirname)
258+
if err != nil {
259+
return err
260+
}
261+
packageNames.Store(dirname, packageName)
262+
}
234263
queue <- f
235264
}
236265

@@ -252,9 +281,14 @@ func writeGoFile(
252281
outFile string,
253282
queries []*Query,
254283
) error {
255-
packageName, err := getPackageName(outFile)
284+
dirname, err := getDirName(outFile)
256285
if err != nil {
257-
log.Fatal(err)
286+
return err
287+
}
288+
289+
packageName, ok := packageNames.Load(dirname)
290+
if !ok {
291+
return fmt.Errorf("no package name found for %q", outFile)
258292
}
259293

260294
var imports []string
@@ -287,16 +321,20 @@ func writeGoFile(
287321
return nil
288322
}
289323

290-
// getPackageName looks up the package name from the first adjacent .go file it
291-
// finds. If there are no adjacent .go files it uses the lower case version of
292-
// the directory name as the package name.
293-
func getPackageName(outFile string) (string, error) {
294-
outFile, err := filepath.Abs(outFile)
324+
func getDirName(file string) (string, error) {
325+
file, err := filepath.Abs(file)
295326
if err != nil {
296327
return "", err
297328
}
298329

299-
dirname := filepath.Dir(outFile)
330+
dirname := filepath.Dir(file)
331+
return dirname, nil
332+
}
333+
334+
// getPackageName looks up the package name from the first adjacent .go file it
335+
// finds. If there are no adjacent .go files it uses the lower case version of
336+
// the directory name as the package name.
337+
func getPackageName(dirname string) (string, error) {
300338
entries, err := os.ReadDir(dirname)
301339
if err != nil {
302340
return "", err

0 commit comments

Comments
 (0)