|
4 | 4 | package db |
5 | 5 |
|
6 | 6 | import ( |
| 7 | + "bufio" |
7 | 8 | "context" |
8 | 9 | "embed" |
9 | 10 | "fmt" |
| 11 | + "io" |
10 | 12 | "os" |
| 13 | + "os/exec" |
11 | 14 | "regexp" |
| 15 | + "strings" |
| 16 | + "sync" |
12 | 17 | "testing" |
13 | 18 | "time" |
14 | 19 |
|
@@ -116,8 +121,14 @@ func RunMigrations(ctx context.Context, uri string) error { |
116 | 121 | } |
117 | 122 |
|
118 | 123 | func NewTransientDB(t *testing.T) (*spanner.Client, context.Context) { |
119 | | - // For now let's create a transient spanner DB. |
120 | | - // We could also spawn a custom spanner emulator per each test. |
| 124 | + // If the environment contains the emulator binary, start it. |
| 125 | + if bin := os.Getenv("SPANNER_EMULATOR_BIN"); bin != "" { |
| 126 | + host := spannerTestWrapper(t, bin) |
| 127 | + os.Setenv("SPANNER_EMULATOR_HOST", host) |
| 128 | + } else if os.Getenv("CI") != "" { |
| 129 | + // We do want to always run these tests on CI. |
| 130 | + t.Fatalf("CI is set, but SPANNER_EMULATOR_BIN is empty") |
| 131 | + } |
121 | 132 | if os.Getenv("SPANNER_EMULATOR_HOST") == "" { |
122 | 133 | t.Skip("SPANNER_EMULATOR_HOST must be set") |
123 | 134 | return nil, nil |
@@ -154,6 +165,66 @@ func NewTransientDB(t *testing.T) (*spanner.Client, context.Context) { |
154 | 165 | return client, ctx |
155 | 166 | } |
156 | 167 |
|
| 168 | +var setupSpannerOnce sync.Once |
| 169 | +var spannerHost string |
| 170 | + |
| 171 | +func spannerTestWrapper(t *testing.T, bin string) string { |
| 172 | + setupSpannerOnce.Do(func() { |
| 173 | + t.Logf("this could be the first test requiring a Spanner emulator, starting %s", bin) |
| 174 | + cmd, host, err := runSpanner(bin) |
| 175 | + if err != nil { |
| 176 | + t.Fatal(err) |
| 177 | + } |
| 178 | + spannerHost = host |
| 179 | + t.Cleanup(func() { |
| 180 | + cmd.Process.Kill() |
| 181 | + cmd.Wait() |
| 182 | + }) |
| 183 | + }) |
| 184 | + return spannerHost |
| 185 | +} |
| 186 | + |
| 187 | +var portRe = regexp.MustCompile(`Server address: ([\w:]+)`) |
| 188 | + |
| 189 | +func runSpanner(bin string) (*exec.Cmd, string, error) { |
| 190 | + cmd := exec.Command(bin, "--override_max_databases_per_instance=1000", |
| 191 | + "--grpc_port=0", "--http_port=0") |
| 192 | + stdout, err := cmd.StdoutPipe() |
| 193 | + if err != nil { |
| 194 | + return nil, "", err |
| 195 | + } |
| 196 | + cmd.Stderr = cmd.Stdout |
| 197 | + if err := cmd.Start(); err != nil { |
| 198 | + return nil, "", err |
| 199 | + } |
| 200 | + scanner := bufio.NewScanner(stdout) |
| 201 | + started, host := false, "" |
| 202 | + for scanner.Scan() { |
| 203 | + line := scanner.Text() |
| 204 | + if strings.Contains(line, "Cloud Spanner Emulator running") { |
| 205 | + started = true |
| 206 | + } else if parts := portRe.FindStringSubmatch(line); parts != nil { |
| 207 | + host = parts[1] |
| 208 | + } |
| 209 | + if started && host != "" { |
| 210 | + break |
| 211 | + } |
| 212 | + } |
| 213 | + if err := scanner.Err(); err != nil { |
| 214 | + return cmd, "", err |
| 215 | + } |
| 216 | + // The program may block if we don't read out all the remaining output. |
| 217 | + go io.Copy(io.Discard, stdout) |
| 218 | + |
| 219 | + if !started { |
| 220 | + return cmd, "", fmt.Errorf("the emulator did not print that it started") |
| 221 | + } |
| 222 | + if host == "" { |
| 223 | + return cmd, "", fmt.Errorf("did not detect the host") |
| 224 | + } |
| 225 | + return cmd, host, nil |
| 226 | +} |
| 227 | + |
157 | 228 | func readOne[T any](iter *spanner.RowIterator) (*T, error) { |
158 | 229 | row, err := iter.Next() |
159 | 230 | if err == iterator.Done { |
|
0 commit comments