Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion cmd/svcinit/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,14 @@ func buildTestEnv(ports svclib.Ports) ([]string, error) {
panic(err)
}

replacements := make([]Replacement, 0, len(ports))
tmpDir := os.Getenv("TMPDIR")
socketDir := os.Getenv("SOCKET_DIR")

replacements := make([]Replacement, 0, 2+len(ports))
replacements = append(replacements,
Replacement{Old: "$${TMPDIR}", New: tmpDir},
Replacement{Old: "$${SOCKET_DIR}", New: socketDir},
)
for label, port := range ports {
replacements = append(replacements, Replacement{
Old: "$${" + label + "}",
Expand Down
14 changes: 12 additions & 2 deletions private/itest.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
_ServiceGroupInfo = provider(
doc = "Info about a service group",
fields = {
"deferred": "Flag if this service should be deferred or not",
"services": "Dict of services/tasks",
},
)
Expand Down Expand Up @@ -123,6 +124,11 @@ def _compute_env(ctx, underlying_target):
return env

def _itest_binary_impl(ctx, extra_service_spec_kwargs, extra_exe_runfiles = []):
if hasattr(ctx.attr, "deferred") and not ctx.attr.deferred:
for dep in ctx.attr.deps:
if dep[_ServiceGroupInfo].deferred:
fail("Non-deferred itest_service cannot depend on deferred itest_service: %s depends on %s" % (ctx.label, dep.label))

exe_runfiles = [ctx.attr.exe.default_runfiles] + extra_exe_runfiles

version_file_deps = ctx.files.data + ctx.files.exe
Expand Down Expand Up @@ -165,7 +171,7 @@ def _itest_binary_impl(ctx, extra_service_spec_kwargs, extra_exe_runfiles = []):
return [
RunEnvironmentInfo(environment = _run_environment(ctx, service_specs_file)),
DefaultInfo(runfiles = runfiles),
_ServiceGroupInfo(services = services),
_ServiceGroupInfo(services = services, deferred = getattr(ctx.attr, "deferred", False)),
]

def _validate_duration(name, s):
Expand Down Expand Up @@ -194,6 +200,7 @@ def _itest_service_impl(ctx):
"http_health_check_address": ctx.attr.http_health_check_address,
"autoassign_port": ctx.attr.autoassign_port,
"so_reuseport_aware": ctx.attr.so_reuseport_aware,
"deferred": ctx.attr.deferred,
"named_ports": ctx.attr.named_ports,
"hot_reloadable": ctx.attr.hot_reloadable,
"expected_start_duration": ctx.attr.expected_start_duration,
Expand Down Expand Up @@ -250,6 +257,9 @@ _itest_service_attrs = _itest_binary_attrs | {

Must only be set when `autoassign_port` is enabled or `named_ports` are used.""",
),
"deferred": attr.bool(
doc = """If set, the service manager will not be start on boot up. It can be started using the service manager's control API.""",
),
"expected_start_duration": attr.string(
default = "0s",
doc = "How long the service expected to take before passing a healthcheck. Any failing health checks before this duration elapses will not be logged.",
Expand Down Expand Up @@ -345,7 +355,7 @@ def _itest_service_group_impl(ctx):
return [
RunEnvironmentInfo(environment = _run_environment(ctx, service_specs_file)),
DefaultInfo(runfiles = runfiles),
_ServiceGroupInfo(services = services),
_ServiceGroupInfo(services = services, deferred = False),

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think group also needs a deferred prop, otherwise you can have (non-deferred service) depend on (group) depend on (deferred service) and it will defeat the check you added

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same thing with tasks, right? let's just add the attribute on all of them

]

_itest_service_group_attrs = _svcinit_attrs | {
Expand Down
15 changes: 12 additions & 3 deletions runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,11 @@ func (r *Runner) StartAll(serviceErrCh chan error) ([]topological.Task, error) {
return nil
}

if service.Deferred {
log.Printf("Deferring %s\n", colorize(service.VersionedServiceSpec))
return nil
}

if terseOutput {
log.Printf("Starting %s\n", colorize(service.VersionedServiceSpec))
} else {
Expand Down Expand Up @@ -85,6 +90,10 @@ func (r *Runner) StartAll(serviceErrCh chan error) ([]topological.Task, error) {
continue
}

if service.Deferred {
continue
}

// TODO(zbarsky): Can remove the loop var once Go is sufficiently upgraded.
go func(service *ServiceInstance) {
err := service.Wait()
Expand All @@ -99,7 +108,7 @@ func (r *Runner) StartAll(serviceErrCh chan error) ([]topological.Task, error) {

func (r *Runner) StopAll() (map[string]*os.ProcessState, error) {
tasks := allTasks(r.serviceInstances, func(ctx context.Context, service *ServiceInstance) error {
if service.Type == "group" {
if service.Type == "group" || service.Deferred {
return nil
}
log.Printf("Stopping %s\n", colorize(service.VersionedServiceSpec))
Expand All @@ -111,7 +120,7 @@ func (r *Runner) StopAll() (map[string]*os.ProcessState, error) {
states := make(map[string]*os.ProcessState)

for _, serviceInstance := range r.serviceInstances {
if serviceInstance.Type == "group" {
if serviceInstance.Type == "group" || serviceInstance.Deferred {
continue
}
states[serviceInstance.Label] = serviceInstance.ProcessState()
Expand Down Expand Up @@ -263,7 +272,7 @@ func initializeServiceCmd(ctx context.Context, instance *ServiceInstance) error

// Even if a child process exits, Wait will block until the I/O pipes are closed.
// They may have been forwarded to an orphaned child, so we disable that behavior to unblock exit.
if s.Type == "service" {
if s.Type == "service" && !s.Deferred {
// We need a bit of grace period to allow I/O pipes to close on our end.
cmd.WaitDelay = 50 * time.Millisecond
}
Expand Down
38 changes: 34 additions & 4 deletions svcctl/svcctl.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ package svcctl

import (
"context"
"errors"
"fmt"
"log"
"net"
"net/http"
"os/exec"
Expand Down Expand Up @@ -58,13 +60,39 @@ func handleHealthCheck(ctx context.Context, r *runner.Runner, _ chan error, w ht
w.WriteHeader(http.StatusOK)
}

func colorize(s svclib.VersionedServiceSpec) string {
return s.Colorize(s.Label)
}

func handleStart(ctx context.Context, r *runner.Runner, serviceErrCh chan error, w http.ResponseWriter, req *http.Request) {
s, status, err := getService(r, req)
if err != nil {
http.Error(w, err.Error(), status)
return
}

if s.Deferred {
// make sure all the non-deferred dependencies are started
for _, dep := range s.Deps {
depService := r.GetInstance(dep)
if depService == nil {
http.Error(w, fmt.Sprintf("dependency %q not found", dep), http.StatusInternalServerError)
return
}

if depService.Deferred {
continue
}

depsErr := s.WaitUntilHealthy(ctx)
if depsErr != nil {
http.Error(w, fmt.Sprintf("Failed to wait for %q until healthy", dep), http.StatusInternalServerError)
}
}
}

log.Printf("Starting %s\n", colorize(s.VersionedServiceSpec))

err = s.Start(ctx)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand All @@ -74,8 +102,8 @@ func handleStart(ctx context.Context, r *runner.Runner, serviceErrCh chan error,
// NOTE: it is important to wait here because we started the service without using `StartAll`,
// which waits for processes to prevent them from turning into zombies.
go func() {
err := s.Wait()
if err != nil && !s.Killed() {
waitErr := s.Wait()
if waitErr != nil && !s.Killed() {
serviceErrCh <- fmt.Errorf(s.Colorize(s.Label) + " exited with error: " + err.Error())
}
}()
Expand Down Expand Up @@ -147,9 +175,11 @@ func handleWait(ctx context.Context, r *runner.Runner, _ chan error, w http.Resp
w.Write([]byte("0"))
return
}
if err, ok := err.(*exec.ExitError); ok {

var exitErr *exec.ExitError
if errors.As(err, &exitErr) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(fmt.Sprintf("%d", err.ExitCode())))
w.Write([]byte(fmt.Sprintf("%d", exitErr.ExitCode())))
return
}
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down
1 change: 1 addition & 0 deletions svclib/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type ServiceSpec struct {
ShutdownSignal string `json:"shutdown_signal"`
ShutdownTimeout string `json:"shutdown_timeout"`
EnforceForcefulShutdown bool `json:"enforce_graceful_shutdown"`
Deferred bool `json:"deferred"`
}

// Our internal representation.
Expand Down
60 changes: 60 additions & 0 deletions tests/deferred/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
load("@rules_go//go:def.bzl", "go_binary", "go_library", "go_test")
load("@rules_itest//:itest.bzl", "itest_service", "itest_task", "service_test")
load(":tests.bzl", "tests")

tests()

go_library(
name = "deferred_lib",
srcs = ["deferred_service.go"],
importpath = "rules_itest/tests/deferred",
visibility = ["//visibility:private"],
)

go_binary(
name = "deferred",
embed = [":deferred_lib"],
visibility = ["//visibility:public"],
)

go_test(
name = "deferred_test",
srcs = ["start_deferred_service_test.go"],
embed = [":deferred_lib"],
tags = ["manual"],
deps = ["//svcctl"],
)

service_test(
name = "deferred_service_test",
services = [":deferred_itest_service"],
test = ":deferred_test",
)

itest_service(
name = "deferred_itest_service",
args = [
"$${PORT}",
],
autoassign_port = True,
deferred = True,
exe = ":deferred",
http_health_check_address = "http://localhost:$${PORT}/healthz",
tags = ["requires-network"],
)

itest_service(
name = "deferred_task",
deferred = True,
exe = "@rules_itest//:exit0",
hygienic = False,
)

itest_service(
name = "non_deferred_depends_on_deferred_should_fail",
exe = "@rules_itest//:exit0",
tags = ["manual"],
deps = [
":deferred_task",
],
)
95 changes: 95 additions & 0 deletions tests/deferred/deferred_service.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package main

import (
"context"
"encoding/json"
"log"
"net/http"
"os"
"os/signal"
"strconv"
"sync"
"syscall"
"time"
)

var (
mu sync.RWMutex
value string
)

func main() {
mux := http.NewServeMux()
mux.HandleFunc("/healthz", healthHandler)
mux.HandleFunc("/update", updateHandler)
mux.HandleFunc("/value", valueHandler)

port, err := strconv.ParseInt(os.Args[1], 10, 64)
if err != nil {
log.Fatalf("Invalid port: %v", err)
}

server := &http.Server{
Addr: "0.0.0.0:" + strconv.FormatInt(port, 10),
Handler: mux,
}

// Listen for SIGTERM for graceful shutdown
stop := make(chan os.Signal, 1)
signal.Notify(stop, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)

go func() {
log.Println("Server starting on :" + strconv.FormatInt(port, 10))
if err := server.ListenAndServe(); err != http.ErrServerClosed {
log.Fatalf("ListenAndServe: %v", err)
}
}()

<-stop
log.Println("Shutting down...")

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

if err := server.Shutdown(ctx); err != nil {
log.Fatalf("Server Shutdown Failed:%+v", err)
}
log.Println("Server exited gracefully")
}

func healthHandler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write([]byte(`ok`))
}

func updateHandler(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "Only POST allowed", http.StatusMethodNotAllowed)
return
}
var payload struct {
Value string `json:"value"`
}
err := json.NewDecoder(r.Body).Decode(&payload)
if err != nil {
http.Error(w, "Invalid JSON", http.StatusBadRequest)
return
}
mu.Lock()
value = payload.Value
mu.Unlock()
w.WriteHeader(http.StatusOK)
w.Write([]byte(`updated`))
}

func valueHandler(w http.ResponseWriter, r *http.Request) {
mu.RLock()
defer mu.RUnlock()
resp := struct {
Value string `json:"value"`
}{
Value: value,
}
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(resp)
}
Loading