Skip to content

Commit bc72fae

Browse files
authored
Move run and restart into the lifecycle manager (#368)
This refactors the code for running and restarting an MCP container out of the CLI and into the lifecycle management interface.
1 parent e8efa1b commit bc72fae

File tree

13 files changed

+362
-410
lines changed

13 files changed

+362
-410
lines changed

cmd/thv/app/common.go

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -67,15 +67,3 @@ func SetSecretsProvider(provider secrets.ProviderType) error {
6767
fmt.Printf("Secrets provider type updated to: %s\n", provider)
6868
return nil
6969
}
70-
71-
// NeedSecretsPassword returns true if the secrets provider requires a password.
72-
func NeedSecretsPassword(secretOptions []string) bool {
73-
// If the user did not ask for any secrets, then don't attempt to instantiate
74-
// the secrets manager.
75-
if len(secretOptions) == 0 {
76-
return false
77-
}
78-
// Ignore err - if the flag is not set, it's not needed.
79-
providerType, _ := config.GetConfig().Secrets.GetProviderType()
80-
return providerType == secrets.EncryptedType
81-
}

cmd/thv/app/list.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ func listCmdFunc(cmd *cobra.Command, _ []string) error {
6666
}
6767

6868
if len(toolHiveContainers) == 0 {
69-
logger.Infof("No MCP servers found")
69+
logger.Info("No MCP servers found")
7070
return nil
7171
}
7272

@@ -140,7 +140,7 @@ func printJSONOutput(containers []rt.ContainerInfo) error {
140140
}
141141

142142
// Print JSON
143-
logger.Infof(string(jsonData))
143+
logger.Info(string(jsonData))
144144
return nil
145145
}
146146

@@ -192,7 +192,7 @@ func printMCPServersOutput(containers []rt.ContainerInfo) error {
192192
}
193193

194194
// Print JSON
195-
logger.Infof(string(jsonData))
195+
logger.Info(string(jsonData))
196196
return nil
197197
}
198198

cmd/thv/app/proxy.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
6464

6565
// Create JWT validator if OIDC flags are provided
6666
if IsOIDCEnabled(cmd) {
67-
logger.Infof("OIDC validation enabled")
67+
logger.Info("OIDC validation enabled")
6868

6969
// Get OIDC flag values
7070
issuer := GetStringFlagOrEmpty(cmd, "oidc-issuer")
@@ -86,7 +86,7 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
8686
// Add JWT validation middleware
8787
middlewares = append(middlewares, jwtValidator.Middleware)
8888
} else {
89-
logger.Infof("OIDC validation disabled")
89+
logger.Info("OIDC validation disabled")
9090
}
9191

9292
// Create the transparent proxy
@@ -101,7 +101,7 @@ func proxyCmdFunc(cmd *cobra.Command, args []string) error {
101101

102102
logger.Infof("Transparent proxy started for server %s on port %d -> %s",
103103
serverName, port, proxyTargetURI)
104-
logger.Infof("Press Ctrl+C to stop")
104+
logger.Info("Press Ctrl+C to stop")
105105

106106
// Set up signal handling
107107
sigCh := make(chan os.Signal, 1)

cmd/thv/app/restart.go

Lines changed: 5 additions & 160 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,11 @@
11
package app
22

33
import (
4-
"context"
54
"fmt"
65

76
"github.com/spf13/cobra"
87

9-
"github.com/stacklok/toolhive/pkg/container"
10-
rt "github.com/stacklok/toolhive/pkg/container/runtime"
11-
"github.com/stacklok/toolhive/pkg/labels"
128
"github.com/stacklok/toolhive/pkg/lifecycle"
13-
"github.com/stacklok/toolhive/pkg/logger"
14-
"github.com/stacklok/toolhive/pkg/process"
15-
"github.com/stacklok/toolhive/pkg/runner"
169
)
1710

1811
var restartCmd = &cobra.Command{
@@ -32,160 +25,12 @@ func restartCmdFunc(cmd *cobra.Command, args []string) error {
3225
// Get container name
3326
containerName := args[0]
3427

35-
// Create container runtime
36-
runtime, err := container.NewFactory().Create(ctx)
28+
// Create lifecycle manager.
29+
manager, err := lifecycle.NewManager(ctx)
3730
if err != nil {
38-
return fmt.Errorf("failed to create container runtime: %v", err)
31+
return fmt.Errorf("failed to create lifecycle manager: %v", err)
3932
}
4033

41-
// Try to find the container ID
42-
containerID, err := findContainerID(ctx, runtime, containerName)
43-
var containerBaseName string
44-
var running bool
45-
46-
if err != nil {
47-
logger.Warnf("Warning: Failed to find container: %v", err)
48-
logger.Warnf("Trying to find state with name %s directly...", containerName)
49-
50-
// Try to use the provided name as the base name
51-
containerBaseName = containerName
52-
running = false
53-
} else {
54-
// Container found, check if it's running
55-
running, err = runtime.IsContainerRunning(ctx, containerID)
56-
if err != nil {
57-
return fmt.Errorf("failed to check if container is running: %v", err)
58-
}
59-
60-
// Get the base container name
61-
containerBaseName, err = getContainerBaseName(ctx, runtime, containerID)
62-
if err != nil {
63-
logger.Warnf("Warning: Could not find base container name in labels: %v", err)
64-
logger.Warnf("Using provided name %s as base name", containerName)
65-
containerBaseName = containerName
66-
}
67-
}
68-
69-
// Check if the proxy process is running
70-
proxyRunning := isProxyRunning(containerBaseName)
71-
72-
if running && proxyRunning {
73-
logger.Infof("Container %s and proxy are already running", containerName)
74-
return nil
75-
}
76-
77-
// If the container is running but the proxy is not, stop the container first
78-
if containerID != "" && running && !proxyRunning {
79-
logger.Infof("Container %s is running but proxy is not. Stopping container...", containerName)
80-
if err := runtime.StopContainer(ctx, containerID); err != nil {
81-
return fmt.Errorf("failed to stop container: %v", err)
82-
}
83-
logger.Infof("Container %s stopped", containerName)
84-
}
85-
86-
// Load the configuration from the state store
87-
mcpRunner, err := loadRunnerFromState(ctx, containerBaseName, runtime)
88-
if err != nil {
89-
return fmt.Errorf("failed to load state for %s: %v", containerBaseName, err)
90-
}
91-
92-
logger.Infof("Loaded configuration from state for %s", containerBaseName)
93-
94-
// Run the tooling server
95-
logger.Infof("Starting tooling server %s...", containerName)
96-
return RunMCPServer(ctx, mcpRunner.Config, false)
97-
}
98-
99-
// isProxyRunning checks if the proxy process is running
100-
func isProxyRunning(containerBaseName string) bool {
101-
if containerBaseName == "" {
102-
return false
103-
}
104-
105-
// Try to read the PID file
106-
pid, err := process.ReadPIDFile(containerBaseName)
107-
if err != nil {
108-
return false
109-
}
110-
111-
// Check if the process exists and is running
112-
isRunning, err := process.FindProcess(pid)
113-
if err != nil {
114-
logger.Warnf("Warning: Error checking process: %v", err)
115-
return false
116-
}
117-
118-
return isRunning
119-
}
120-
121-
// loadRunnerFromState attempts to load a Runner from the state store
122-
func loadRunnerFromState(ctx context.Context, baseName string, runtime rt.Runtime) (*runner.Runner, error) {
123-
// Load the runner from the state store
124-
r, err := runner.LoadState(ctx, baseName)
125-
if err != nil {
126-
return nil, err
127-
}
128-
129-
// Update the runtime in the loaded configuration
130-
r.Config.Runtime = runtime
131-
132-
return r, nil
133-
}
134-
135-
/*
136-
* The following functions are duplicated in container/manager.go until
137-
* we can refactor the code to avoid this duplication.
138-
*/
139-
140-
// getContainerBaseName gets the base container name from the container labels
141-
func getContainerBaseName(ctx context.Context, runtime rt.Runtime, containerID string) (string, error) {
142-
containers, err := runtime.ListContainers(ctx)
143-
if err != nil {
144-
return "", fmt.Errorf("failed to list containers: %v", err)
145-
}
146-
147-
for _, c := range containers {
148-
if c.ID == containerID {
149-
return labels.GetContainerBaseName(c.Labels), nil
150-
}
151-
}
152-
153-
return "", fmt.Errorf("container %s not found", containerID)
154-
}
155-
156-
func findContainerID(ctx context.Context, runtime rt.Runtime, name string) (string, error) {
157-
c, err := findContainerByName(ctx, runtime, name)
158-
if err != nil {
159-
return "", err
160-
}
161-
return c.ID, nil
162-
}
163-
164-
func findContainerByName(ctx context.Context, runtime rt.Runtime, name string) (*rt.ContainerInfo, error) {
165-
// List containers to find the one with the given name
166-
containers, err := runtime.ListContainers(ctx)
167-
if err != nil {
168-
return nil, fmt.Errorf("failed to list containers: %v", err)
169-
}
170-
171-
// Find the container with the given name
172-
for _, c := range containers {
173-
// Check if the container is managed by ToolHive
174-
if !labels.IsToolHiveContainer(c.Labels) {
175-
continue
176-
}
177-
178-
// Check if the container name matches
179-
containerName := labels.GetContainerName(c.Labels)
180-
if containerName == "" {
181-
name = c.Name // Fallback to container name
182-
}
183-
184-
// Check if the name matches (exact match or prefix match)
185-
if containerName == name || c.ID == name {
186-
return &c, nil
187-
}
188-
}
189-
190-
return nil, fmt.Errorf("%w: %s", lifecycle.ErrContainerNotFound, name)
34+
// Restart the container in a detached process.
35+
return manager.RestartContainer(ctx, containerName)
19136
}

cmd/thv/app/run.go

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package app
22

33
import (
44
"context"
5-
"encoding/json"
65
"fmt"
76
"os"
87
"strings"
@@ -13,6 +12,7 @@ import (
1312

1413
"github.com/stacklok/toolhive/pkg/container"
1514
"github.com/stacklok/toolhive/pkg/container/runtime"
15+
"github.com/stacklok/toolhive/pkg/lifecycle"
1616
"github.com/stacklok/toolhive/pkg/logger"
1717
"github.com/stacklok/toolhive/pkg/permissions"
1818
"github.com/stacklok/toolhive/pkg/registry"
@@ -323,7 +323,7 @@ func applyRegistrySettings(
323323

324324
// Create a temporary file for the permission profile if not explicitly provided
325325
if !cmd.Flags().Changed("permission-profile") {
326-
permProfilePath, err := createPermissionProfileFile(serverName, server.Permissions, debugMode)
326+
permProfilePath, err := lifecycle.CreatePermissionProfileFile(serverName, server.Permissions)
327327
if err != nil {
328328
// Just log the error and continue with the default permission profile
329329
logger.Warnf("Warning: Failed to create permission profile file: %v", err)
@@ -417,33 +417,6 @@ func hasLatestTag(imageRef string) bool {
417417
return !isDigest
418418
}
419419

420-
// createPermissionProfileFile creates a temporary file with the permission profile
421-
func createPermissionProfileFile(serverName string, permProfile *permissions.Profile, debugMode bool) (string, error) {
422-
tempFile, err := os.CreateTemp("", fmt.Sprintf("toolhive-%s-permissions-*.json", serverName))
423-
if err != nil {
424-
return "", fmt.Errorf("failed to create temporary file: %v", err)
425-
}
426-
defer tempFile.Close()
427-
428-
// Get the temporary file path
429-
permProfilePath := tempFile.Name()
430-
431-
// Serialize the permission profile to JSON
432-
permProfileJSON, err := json.Marshal(permProfile)
433-
if err != nil {
434-
return "", fmt.Errorf("failed to serialize permission profile: %v", err)
435-
}
436-
437-
// Write the permission profile to the temporary file
438-
if _, err := tempFile.Write(permProfileJSON); err != nil {
439-
return "", fmt.Errorf("failed to write permission profile to file: %v", err)
440-
}
441-
442-
logDebug(debugMode, "Wrote permission profile to temporary file: %s", permProfilePath)
443-
444-
return permProfilePath, nil
445-
}
446-
447420
// logDebug logs a message if debug mode is enabled
448421
func logDebug(debugMode bool, format string, args ...interface{}) {
449422
if debugMode {

0 commit comments

Comments
 (0)