diff --git a/checkpointctl.go b/checkpointctl.go index af7bd57d..c26aa1e1 100644 --- a/checkpointctl.go +++ b/checkpointctl.go @@ -29,6 +29,7 @@ func main() { rootCommand.AddCommand(cmd.List()) rootCommand.AddCommand(cmd.BuildCmd()) rootCommand.AddCommand(cmd.PluginCmd()) + rootCommand.AddCommand(cmd.EditCmd()) // Discover and register external plugins from PATH. // Plugins are executables named checkpointctl- where diff --git a/cmd/edit.go b/cmd/edit.go new file mode 100644 index 00000000..57f16aae --- /dev/null +++ b/cmd/edit.go @@ -0,0 +1,42 @@ +package cmd + +import ( + "fmt" + + "github.com/checkpoint-restore/checkpointctl/internal" + "github.com/spf13/cobra" +) + +var tcpListenRemapFlag string + +func EditCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "edit ", + Short: "Edit a checkpoint archive", + Long: `The 'edit' command can help you change the properties of a container inside a checkpoint archive. +Currently only supports remapping the TCP listen ports. +Example: + checkpointctl edit --tcp-listen-remap 8080:80 checkpoint.tar`, + Args: cobra.ExactArgs(1), + RunE: editArchive, + } + + cmd.Flags().StringVar( + &tcpListenRemapFlag, + "tcp-listen-remap", + "", + "Remap TCP listen port (format: oldport:newport)", + ) + + return cmd +} + +func editArchive(cmd *cobra.Command, args []string) error { + archivePath := args[0] + + if tcpListenRemapFlag != "" { + return internal.TcpListenRemap(tcpListenRemapFlag, archivePath) + } + + return fmt.Errorf("no edit operation specified; use --tcp-listen-remap") +} diff --git a/internal/archive_modifiers.go b/internal/archive_modifiers.go new file mode 100644 index 00000000..10bbf40a --- /dev/null +++ b/internal/archive_modifiers.go @@ -0,0 +1,237 @@ +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "archive/tar" + "bytes" + "encoding/json" + "fmt" + "io" + "os" + "strconv" + + "github.com/checkpoint-restore/go-criu/v8/crit" + "github.com/checkpoint-restore/go-criu/v8/crit/images/fdinfo" +) + +// TCP_LISTEN state value from the Linux kernel +const tcpListenState = 10 + +// remapFilesImg decodes a CRIU files.img binary image, remaps the source port +// of any TCP listen socket matching oldPort to newPort, and re-encodes the image. +func remapFilesImg(hdr *tar.Header, content io.Reader, oldPort, newPort uint32) (*tar.Header, []byte, error) { + // crit.New requires *os.File, so write the tar entry content to a temp file + tmpIn, err := os.CreateTemp("", "files-img-in-*.img") + if err != nil { + return nil, nil, fmt.Errorf("creating temp input file: %w", err) + } + defer os.Remove(tmpIn.Name()) + defer tmpIn.Close() + + if _, err := io.Copy(tmpIn, content); err != nil { + return nil, nil, fmt.Errorf("writing to temp file: %w", err) + } + if _, err := tmpIn.Seek(0, 0); err != nil { + return nil, nil, fmt.Errorf("seeking temp file: %w", err) + } + + // Decode the binary image + c := crit.New(tmpIn, nil, "", false, false) + img, err := c.Decode(&fdinfo.FileEntry{}) + if err != nil { + return nil, nil, fmt.Errorf("decoding files.img: %w", err) + } + + // Walk every entry looking for TCP listen sockets on the old port + remapped := 0 + for _, entry := range img.Entries { + fileEntry, ok := entry.Message.(*fdinfo.FileEntry) + if !ok { + continue + } + if fileEntry.GetType() != fdinfo.FdTypes_INETSK { + continue + } + isk := fileEntry.GetIsk() + if isk == nil { + continue + } + if isk.GetState() == tcpListenState && isk.GetSrcPort() == oldPort { + np := newPort + isk.SrcPort = &np + remapped++ + } + } + + if remapped == 0 { + return nil, nil, fmt.Errorf("no TCP listen sockets found with source port %d", oldPort) + } + + // Encode the modified image to another temp file + tmpOut, err := os.CreateTemp("", "files-img-out-*.img") + if err != nil { + return nil, nil, fmt.Errorf("creating temp output file: %w", err) + } + defer os.Remove(tmpOut.Name()) + defer tmpOut.Close() + + cOut := crit.New(nil, tmpOut, "", false, false) + if err := cOut.Encode(img); err != nil { + return nil, nil, fmt.Errorf("encoding files.img: %w", err) + } + + // Read the re-encoded bytes + if _, err := tmpOut.Seek(0, 0); err != nil { + return nil, nil, fmt.Errorf("seeking output file: %w", err) + } + var buf bytes.Buffer + if _, err := io.Copy(&buf, tmpOut); err != nil { + return nil, nil, fmt.Errorf("reading output file: %w", err) + } + + // Update the tar header to reflect the new size + hdr.Size = int64(buf.Len()) + return hdr, buf.Bytes(), nil +} + +// remapConfigDump modifies the config dump in a Podman checkpoint to update: +// - Port mappings +// - PORT environment variable in any nested env arrays +// Returns silently for other runtime checkpoints. +func remapConfigDump(hdr *tar.Header, content io.Reader, oldPort, newPort string) (*tar.Header, []byte, error) { + data, err := io.ReadAll(content) + if err != nil { + return nil, nil, fmt.Errorf("reading config.dump: %w", err) + } + + // Parse into a generic map to preserve all fields + var config map[string]any + if err := json.Unmarshal(data, &config); err != nil { + return nil, nil, fmt.Errorf("parsing config.dump JSON: %w", err) + } + + remapPortMappings(config, oldPort, newPort) + + remapEnvRecursive(config, oldPort, newPort) + + output, err := json.Marshal(config) + if err != nil { + return nil, nil, fmt.Errorf("marshaling config.dump: %w", err) + } + + hdr.Size = int64(len(output)) + return hdr, output, nil +} + +// remapSpecDump modifies the OCI runtime spec JSON to update the PORT env var. +func remapSpecDump(hdr *tar.Header, content io.Reader, oldPort, newPort string) (*tar.Header, []byte, error) { + data, err := io.ReadAll(content) + if err != nil { + return nil, nil, fmt.Errorf("reading spec.dump: %w", err) + } + + var spec map[string]any + if err := json.Unmarshal(data, &spec); err != nil { + return nil, nil, fmt.Errorf("parsing spec.dump JSON: %w", err) + } + + // The env array lives under spec.process.env + if process, ok := spec["process"].(map[string]any); ok { + if envSlice, ok := process["env"].([]any); ok { + process["env"] = remapEnvSlice(envSlice, oldPort, newPort) + } + } + + output, err := json.Marshal(spec) + if err != nil { + return nil, nil, fmt.Errorf("marshaling spec.dump: %w", err) + } + + hdr.Size = int64(len(output)) + return hdr, output, nil +} + +// remapPortMappings updates the container_port field in the objects of +// newPortMappings array in obj. It searches for port mappings where +// container_port matches oldPort and replaces them with newPort. +func remapPortMappings(obj any, oldPort, newPort string) { + m, ok := obj.(map[string]any) + if !ok { + return + } + + mappings, ok := m["newPortMappings"] + if !ok { + return + } + + mappingsSlice, ok := mappings.([]any) + if !ok { + return + } + + for _, mapping := range mappingsSlice { + mappingMap, ok := mapping.(map[string]any) + if !ok { + continue + } + + containerPort, ok := mappingMap["container_port"] + if !ok { + continue + } + + // JSON numbers are unmarshaled as float64 + portFloat, ok := containerPort.(float64) + if !ok { + continue + } + + if strconv.FormatFloat(portFloat, 'f', -1, 64) == oldPort { + newPortNum, _ := strconv.ParseFloat(newPort, 64) + mappingMap["container_port"] = newPortNum + } + } +} + +// remapEnvRecursive walks the structure obj looking for any "env" key +// whose value is an array of strings, and replaces PORT=oldPort with PORT=newPort. +func remapEnvRecursive(obj any, oldPort, newPort string) { + m, ok := obj.(map[string]any) + if !ok { + return + } + for key, val := range m { + if key == "env" { + if envSlice, ok := val.([]any); ok { + m["env"] = remapEnvSlice(envSlice, oldPort, newPort) + } + } else { + switch child := val.(type) { + case map[string]any: + remapEnvRecursive(child, oldPort, newPort) + case []any: + for _, item := range child { + remapEnvRecursive(item, oldPort, newPort) + } + } + } + } +} + +// remapEnvSlice replaces PORT=oldPort with PORT=newPort in an env slice. +func remapEnvSlice(envSlice []any, oldPort, newPort string) []any { + target := "PORT=" + oldPort + replacement := "PORT=" + newPort + for i, v := range envSlice { + s, ok := v.(string) + if !ok { + continue + } + if s == target { + envSlice[i] = replacement + } + } + return envSlice +} diff --git a/internal/archive_modifiers_test.go b/internal/archive_modifiers_test.go new file mode 100644 index 00000000..07d79d63 --- /dev/null +++ b/internal/archive_modifiers_test.go @@ -0,0 +1,532 @@ +package internal + +import ( + "archive/tar" + "bytes" + "encoding/json" + "testing" +) + +func TestRemapEnvSlice(t *testing.T) { + tests := []struct { + name string + env []any + oldPort string + newPort string + expected []any + }{ + { + name: "basic port replacement", + env: []any{"A=1", "PORT=5000", "B=2"}, + oldPort: "5000", + newPort: "9999", + expected: []any{"A=1", "PORT=9999", "B=2"}, + }, + { + name: "non-string element preserved", + env: []any{"PORT=5000", 123}, + oldPort: "5000", + newPort: "9999", + expected: []any{"PORT=9999", 123}, + }, + { + name: "no match - no change", + env: []any{"PORT=8080", "FOO=bar"}, + oldPort: "5000", + newPort: "9999", + expected: []any{"PORT=8080", "FOO=bar"}, + }, + { + name: "multiple PORT entries", + env: []any{"PORT=5000", "OTHER=x", "PORT=5000"}, + oldPort: "5000", + newPort: "9999", + expected: []any{"PORT=9999", "OTHER=x", "PORT=9999"}, + }, + { + name: "empty slice", + env: []any{}, + oldPort: "5000", + newPort: "9999", + expected: []any{}, + }, + { + name: "PORT as substring not replaced", + env: []any{"MYPORT=5000", "EXPORT=5000"}, + oldPort: "5000", + newPort: "9999", + expected: []any{"MYPORT=5000", "EXPORT=5000"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := remapEnvSlice(tt.env, tt.oldPort, tt.newPort) + if len(got) != len(tt.expected) { + t.Errorf("Expected env length %d, got %d", len(tt.expected), len(got)) + } + for i := range tt.expected { + if got[i] != tt.expected[i] { + t.Errorf("env[%d]: expected %v, got %v", i, tt.expected[i], got[i]) + } + } + }) + } +} + +func TestRemapEnvRecursive(t *testing.T) { + tests := []struct { + name string + obj any + oldPort string + newPort string + check func(t *testing.T, obj any) + }{ + { + name: "nested env remapping", + obj: map[string]any{ + "env": []any{"PORT=5000", "FOO=bar"}, + "process": map[string]any{ + "env": []any{"A=1", "PORT=5000"}, + }, + }, + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + rootEnv := m["env"].([]any) + if rootEnv[0] != "PORT=9999" { + t.Errorf("Expected root env PORT to be remapped, got %v", rootEnv[0]) + } + nestedEnv := m["process"].(map[string]any)["env"].([]any) + if nestedEnv[1] != "PORT=9999" { + t.Errorf("Expected nested env PORT to be remapped, got %v", nestedEnv[1]) + } + }, + }, + { + name: "deeply nested env", + obj: map[string]any{ + "level1": map[string]any{ + "level2": map[string]any{ + "env": []any{"PORT=5000"}, + }, + }, + }, + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + env := m["level1"].(map[string]any)["level2"].(map[string]any)["env"].([]any) + if env[0] != "PORT=9999" { + t.Errorf("Expected deeply nested PORT to be remapped, got %v", env[0]) + } + }, + }, + { + name: "env in array elements", + obj: map[string]any{ + "containers": []any{ + map[string]any{"env": []any{"PORT=5000"}}, + map[string]any{"env": []any{"PORT=5000"}}, + }, + }, + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + containers := m["containers"].([]any) + env0 := containers[0].(map[string]any)["env"].([]any) + env1 := containers[1].(map[string]any)["env"].([]any) + if env0[0] != "PORT=9999" || env1[0] != "PORT=9999" { + t.Errorf("Expected all container envs to be remapped") + } + }, + }, + { + name: "no env field - no error", + obj: map[string]any{"other": "data"}, + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + if m["other"] != "data" { + t.Errorf("Expected object to remain unchanged") + } + }, + }, + { + name: "non-map input - no panic", + obj: "not a map", + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) {}, + }, + { + name: "env is not array - no panic", + obj: map[string]any{ + "env": "not an array", + }, + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + if m["env"] != "not an array" { + t.Errorf("Expected env to remain unchanged") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + remapEnvRecursive(tt.obj, tt.oldPort, tt.newPort) + tt.check(t, tt.obj) + }) + } +} + +func TestRemapPortMappings(t *testing.T) { + tests := []struct { + name string + obj any + oldPort string + newPort string + check func(t *testing.T, obj any) + }{ + { + name: "basic port mapping", + obj: map[string]any{ + "newPortMappings": []any{ + map[string]any{"container_port": float64(5000)}, + map[string]any{"container_port": float64(6000)}, + }, + }, + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + mappings := m["newPortMappings"].([]any) + if mappings[0].(map[string]any)["container_port"].(float64) != 9999 { + t.Errorf("Expected first port to be 9999") + } + if mappings[1].(map[string]any)["container_port"].(float64) != 6000 { + t.Errorf("Expected second port to remain 6000") + } + }, + }, + { + name: "port with additional fields", + obj: map[string]any{ + "newPortMappings": []any{ + map[string]any{ + "host_ip": "", + "container_port": float64(8000), + "host_port": float64(8080), + "range": float64(1), + "protocol": "tcp", + }, + }, + }, + oldPort: "8000", + newPort: "9000", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + mappings := m["newPortMappings"].([]any) + mapping := mappings[0].(map[string]any) + if mapping["container_port"].(float64) != 9000 { + t.Errorf("Expected container_port to be 9000") + } + if mapping["host_port"].(float64) != 8080 { + t.Errorf("Expected host_port to remain 8080") + } + if mapping["protocol"] != "tcp" { + t.Errorf("Expected other fields to be preserved") + } + }, + }, + { + name: "no newPortMappings field", + obj: map[string]any{"other": "data"}, + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + if m["other"] != "data" { + t.Errorf("Expected object to remain unchanged") + } + }, + }, + { + name: "empty newPortMappings array", + obj: map[string]any{ + "newPortMappings": []any{}, + }, + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + mappings := m["newPortMappings"].([]any) + if len(mappings) != 0 { + t.Errorf("Expected empty array to remain empty") + } + }, + }, + { + name: "non-map input", + obj: "not a map", + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) {}, + }, + { + name: "newPortMappings not an array", + obj: map[string]any{ + "newPortMappings": "not an array", + }, + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + if m["newPortMappings"] != "not an array" { + t.Errorf("Expected field to remain unchanged") + } + }, + }, + { + name: "mapping without container_port field", + obj: map[string]any{ + "newPortMappings": []any{ + map[string]any{"host_port": float64(8080)}, + }, + }, + oldPort: "5000", + newPort: "9999", + check: func(t *testing.T, obj any) { + m := obj.(map[string]any) + mappings := m["newPortMappings"].([]any) + mapping := mappings[0].(map[string]any) + if _, exists := mapping["container_port"]; exists { + t.Errorf("Expected no container_port to be added") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + remapPortMappings(tt.obj, tt.oldPort, tt.newPort) + tt.check(t, tt.obj) + }) + } +} + +func TestRemapConfigDump(t *testing.T) { + tests := []struct { + name string + input string + oldPort string + newPort string + wantErr bool + check func(t *testing.T, output []byte) + }{ + { + name: "full config with ports and env", + input: `{ + "newPortMappings":[{"container_port":5000},{"container_port":7000}], + "env":["FOO=bar","PORT=5000"], + "nested":{"env":["PORT=5000","X=1"]} + }`, + oldPort: "5000", + newPort: "9999", + wantErr: false, + check: func(t *testing.T, output []byte) { + var got map[string]any + if err := json.Unmarshal(output, &got); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + + ports := got["newPortMappings"].([]any) + if ports[0].(map[string]any)["container_port"].(float64) != 9999 { + t.Errorf("Expected first port to be 9999") + } + if ports[1].(map[string]any)["container_port"].(float64) != 7000 { + t.Errorf("Expected second port to remain 7000") + } + + env := got["env"].([]any) + if env[1] != "PORT=9999" { + t.Errorf("Expected PORT env to be remapped") + } + + nestedEnv := got["nested"].(map[string]any)["env"].([]any) + if nestedEnv[0] != "PORT=9999" { + t.Errorf("Expected nested PORT env to be remapped") + } + }, + }, + { + name: "empty config", + input: "{}", + oldPort: "5000", + newPort: "9999", + wantErr: false, + check: func(t *testing.T, output []byte) { + var got map[string]any + if err := json.Unmarshal(output, &got); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + }, + }, + { + name: "invalid JSON", + input: "{invalid json", + oldPort: "5000", + newPort: "9999", + wantErr: true, + check: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hdr := &tar.Header{Name: "config.dump", Size: int64(len(tt.input))} + newHdr, out, err := remapConfigDump(hdr, bytes.NewReader([]byte(tt.input)), tt.oldPort, tt.newPort) + + if tt.wantErr { + if err == nil { + t.Errorf("Expected error, got none") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if newHdr.Size != int64(len(out)) { + t.Errorf("Expected header size %d, got %d", len(out), newHdr.Size) + } + + if tt.check != nil { + tt.check(t, out) + } + }) + } +} + +func TestRemapSpecDump(t *testing.T) { + tests := []struct { + name string + input string + oldPort string + newPort string + wantErr bool + check func(t *testing.T, output []byte) + }{ + { + name: "basic spec with env", + input: `{"process":{"env":["PORT=5000","FOO=bar"]}}`, + oldPort: "5000", + newPort: "9999", + wantErr: false, + check: func(t *testing.T, output []byte) { + var got map[string]any + if err := json.Unmarshal(output, &got); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + env := got["process"].(map[string]any)["env"].([]any) + if env[0] != "PORT=9999" { + t.Errorf("Expected PORT to be remapped, got %v", env[0]) + } + if env[1] != "FOO=bar" { + t.Errorf("Expected FOO to remain unchanged") + } + }, + }, + { + name: "spec without process", + input: `{"other":"field"}`, + oldPort: "5000", + newPort: "9999", + wantErr: false, + check: func(t *testing.T, output []byte) { + var got map[string]any + if err := json.Unmarshal(output, &got); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + if got["other"] != "field" { + t.Errorf("Expected spec to remain unchanged") + } + }, + }, + { + name: "spec with process but no env", + input: `{"process":{"user":{"uid":0}}}`, + oldPort: "5000", + newPort: "9999", + wantErr: false, + check: func(t *testing.T, output []byte) { + var got map[string]any + if err := json.Unmarshal(output, &got); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + process := got["process"].(map[string]any) + if process["user"].(map[string]any)["uid"].(float64) != 0 { + t.Errorf("Expected process to remain unchanged") + } + }, + }, + { + name: "spec with empty env array", + input: `{"process":{"env":[]}}`, + oldPort: "5000", + newPort: "9999", + wantErr: false, + check: func(t *testing.T, output []byte) { + var got map[string]any + if err := json.Unmarshal(output, &got); err != nil { + t.Fatalf("Failed to unmarshal: %v", err) + } + env := got["process"].(map[string]any)["env"].([]any) + if len(env) != 0 { + t.Errorf("Expected empty env to remain empty") + } + }, + }, + { + name: "invalid JSON", + input: "not json", + oldPort: "5000", + newPort: "9999", + wantErr: true, + check: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + hdr := &tar.Header{Name: "spec.dump", Size: int64(len(tt.input))} + newHdr, out, err := remapSpecDump(hdr, bytes.NewReader([]byte(tt.input)), tt.oldPort, tt.newPort) + + if tt.wantErr { + if err == nil { + t.Errorf("Expected error, got none") + } + return + } + + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + if newHdr.Size != int64(len(out)) { + t.Errorf("Expected header size %d, got %d", len(out), newHdr.Size) + } + + if tt.check != nil { + tt.check(t, out) + } + }) + } +} diff --git a/internal/edit_archive.go b/internal/edit_archive.go new file mode 100644 index 00000000..41d930e1 --- /dev/null +++ b/internal/edit_archive.go @@ -0,0 +1,235 @@ +// SPDX-License-Identifier: Apache-2.0 + +package internal + +import ( + "archive/tar" + "fmt" + "io" + "log" + "os" + "path/filepath" + "strconv" + "strings" + + metadata "github.com/checkpoint-restore/checkpointctl/lib" + "github.com/containers/storage/pkg/archive" +) + +// TcpListenRemap remaps the TCP listen ports specified by the portMapping +// of a checkpoint archive specified by the archivePath and replaces the old +// archive with the new one. +// portMapping uses the format old-port:new-port +func TcpListenRemap(portMapping, archivePath string) error { + ports := strings.SplitN(portMapping, ":", 2) + if len(ports) != 2 { + return fmt.Errorf("invalid parameters: expected oldport:newport") + } + + // Parse ports + oldPort, err := strconv.ParseUint(ports[0], 10, 64) + if err != nil { + return fmt.Errorf("invalid parameters: expected oldport:newport") + } + if oldPort == 0 || oldPort > 65535 { + return fmt.Errorf("old port %d is out of valid range (1-65535)", oldPort) + } + + newPort, err := strconv.ParseUint(ports[1], 10, 64) + if err != nil { + return fmt.Errorf("invalid parameters: expected oldport:newport") + } + if newPort == 0 || newPort > 65535 { + return fmt.Errorf("new port %d is out of valid range (1-65535)", newPort) + } + + oldPortStr := strconv.FormatUint(oldPort, 10) + newPortStr := strconv.FormatUint(newPort, 10) + + // Define modifiers for tar entries that need port remapping + mods := map[string]func(*tar.Header, io.Reader) (*tar.Header, []byte, error){ + // Remap the TCP listen port in the CRIU binary image + "checkpoint/files.img": func(hdr *tar.Header, content io.Reader) (*tar.Header, []byte, error) { + return remapFilesImg(hdr, content, uint32(oldPort), uint32(newPort)) + }, + // Remap port mappings and PORT env var in the Podman container config + "config.dump": func(hdr *tar.Header, content io.Reader) (*tar.Header, []byte, error) { + return remapConfigDump(hdr, content, oldPortStr, newPortStr) + }, + // Remap PORT env var in the OCI runtime spec + "spec.dump": func(hdr *tar.Header, content io.Reader) (*tar.Header, []byte, error) { + return remapSpecDump(hdr, content, oldPortStr, newPortStr) + }, + } + + if err := tarStreamRewrite(archivePath, mods); err != nil { + return err + } + + log.Printf("Successfully remapped port %d -> %d\n", oldPort, newPort) + + return nil +} + +// tarStreamRewrite streams through a (possibly compressed) tar archive at archivePath, +// applies modifications to entries as specified by the mods map, writes the result to a temporary file, +// and atomically replaces the original archive with the modified one. +// The compression type and file permissions of the original archive are preserved. +func tarStreamRewrite(archivePath string, mods map[string]func(*tar.Header, io.Reader) (*tar.Header, []byte, error)) error { + archiveFile, err := os.Open(archivePath) + if err != nil { + return err + } + defer archiveFile.Close() + + // Check if there is a checkpoint directory in the archive file + checkpointDirExists, err := isFileInArchive(archivePath, metadata.CheckpointDirectory, true) + if err != nil { + return err + } + if !checkpointDirExists { + return fmt.Errorf("checkpoint directory is missing in the archive file: %s", archivePath) + } + + // For getting input archive's permissions later + archiveInfo, err := archiveFile.Stat() + if err != nil { + return fmt.Errorf("failed to stat archive: %w", err) + } + + // Detect Compression + b := make([]byte, 10) + n, err := io.ReadFull(archiveFile, b) + if err != nil { + return fmt.Errorf("failed to read archive magic bytes") + } + comp := archive.DetectCompression(b[:n]) + // Seek back so DecompressStream can read from the start + if _, err := archiveFile.Seek(0, 0); err != nil { + return fmt.Errorf("failed to seek archive") + } + + // Decompress the archive into a plan tar stream + tarStream, err := archive.DecompressStream(archiveFile) + if err != nil { + return fmt.Errorf("failed to decompress archive") + } + defer tarStream.Close() + + // Create output file with compression + outFile, err := os.CreateTemp(filepath.Dir(archivePath), ".checkpointctl-edit-*.tar") + if err != nil { + return fmt.Errorf("failed to create output file: %w", err) + } + outputPath := outFile.Name() + + compressor, err := archive.CompressStream(outFile, comp) + if err != nil { + outFile.Close() + os.Remove(outputPath) + return fmt.Errorf("failed to create compressor") + } + + tarReader := tar.NewReader(tarStream) + tarWriter := tar.NewWriter(compressor) + matchedEntries := 0 + + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + tarWriter.Close() + compressor.Close() + outFile.Close() + os.Remove(outputPath) + return fmt.Errorf("failed to read tar entry: %w", err) + } + + entryName := normalizeArchivePath(header.Name) + + // Check if this entry has a modifier + if modifier, ok := mods[entryName]; ok { + matchedEntries++ + newHeader, data, err := modifier(header, tarReader) + if err != nil { + tarWriter.Close() + compressor.Close() + outFile.Close() + os.Remove(outputPath) + return err + } + if newHeader != nil { + newHeader.Size = int64(len(data)) + if err := tarWriter.WriteHeader(newHeader); err != nil { + tarWriter.Close() + compressor.Close() + outFile.Close() + os.Remove(outputPath) + return fmt.Errorf("failed to write header for %s: %w", header.Name, err) + } + if _, err := tarWriter.Write(data); err != nil { + tarWriter.Close() + compressor.Close() + outFile.Close() + os.Remove(outputPath) + return fmt.Errorf("failed to write data for %s: %w", header.Name, err) + } + } + } else { + // Copy entry unchanged + if err := tarWriter.WriteHeader(header); err != nil { + tarWriter.Close() + compressor.Close() + outFile.Close() + os.Remove(outputPath) + return fmt.Errorf("failed to write header for %s: %w", header.Name, err) + } + if _, err := io.Copy(tarWriter, tarReader); err != nil { + tarWriter.Close() + compressor.Close() + outFile.Close() + os.Remove(outputPath) + return fmt.Errorf("failed to copy data for %s: %w", header.Name, err) + } + } + } + + if err := tarWriter.Close(); err != nil { + compressor.Close() + outFile.Close() + os.Remove(outputPath) + return fmt.Errorf("failed to finalize tar stream: %w", err) + } + if err := compressor.Close(); err != nil { + outFile.Close() + os.Remove(outputPath) + return fmt.Errorf("failed to finalize compressed stream: %w", err) + } + if err := outFile.Close(); err != nil { + os.Remove(outputPath) + return fmt.Errorf("failed to close output file: %w", err) + } + if matchedEntries != len(mods) { + os.Remove(outputPath) + return fmt.Errorf("matching entries not found in archive for requested edit operation") + } + + // Match the output file's permissions to the input archive + if err := os.Chmod(outputPath, archiveInfo.Mode().Perm()); err != nil { + return fmt.Errorf("failed to set output permissions: %w", err) + } + + // Replace the modified archive with the original one + if err := os.Rename(outputPath, archivePath); err != nil { + os.Remove(outputPath) + return fmt.Errorf("failed to replace checkpoint: %w", err) + } + + return nil +} + +func normalizeArchivePath(path string) string { + return strings.TrimPrefix(path, "./") +} diff --git a/test/checkpointctl.bats b/test/checkpointctl.bats index 481f7500..50e31f63 100644 --- a/test/checkpointctl.bats +++ b/test/checkpointctl.bats @@ -1328,3 +1328,110 @@ EOF PATH="$ORIG_PATH" } + +@test "Run checkpointctl edit with no flags" { + checkpointctl edit /does-not-exist + [ "$status" -eq 1 ] + [[ "${output}" == "Error: no edit operation specified; use --tcp-listen-remap" ]] +} + +@test "Run checkpointctl edit with --tcp-listen-remap and invalid port format" { + checkpointctl edit --tcp-listen-remap abc /does-not-exist + [ "$status" -eq 1 ] + [[ "${output}" == "Error: invalid parameters: expected oldport:newport" ]] + + checkpointctl edit --tcp-listen-remap 8080:ab /does-not-exist + [ "$status" -eq 1 ] + [[ "${output}" == "Error: invalid parameters: expected oldport:newport" ]] +} + +@test "Run checkpointctl edit with --tcp-listen-remap and invalid port number" { + checkpointctl edit --tcp-listen-remap -1:8080 /does-not-exist + [ "$status" -eq 1 ] + [[ "$output" == "Error: invalid parameters: expected oldport:newport" ]] + + checkpointctl edit --tcp-listen-remap 8080:65536 /does-not-exist + [ "$status" -eq 1 ] + [[ "$output" == "Error: new port 65536 is out of valid range (1-65535)" ]] +} + +@test "Run checkpointctl edit with --tcp-listen-remap and non existing archive" { + checkpointctl edit --tcp-listen-remap 80:8080 /does-not-exist + [ "$status" -eq 1 ] + [[ "$output" == "Error: open /does-not-exist: no such file or directory" ]] +} + +@test "Run checkpointctl edit with --tcp-listen-remap and empty archive" { + touch "$TEST_TMP_DIR1"/empty.tar + checkpointctl edit --tcp-listen-remap 80:8080 "$TEST_TMP_DIR1"/empty.tar + [ "$status" -eq 1 ] + [[ ${lines[0]} == *"checkpoint directory is missing in the archive file"* ]] +} + +@test "Run checkpointctl edit with --tcp-listen-remap on uncompressed tar" { + cp data/config.dump \ + data/spec.dump "$TEST_TMP_DIR1" + mkdir "$TEST_TMP_DIR1"/checkpoint + cp test-imgs/pstree.img \ + test-imgs/core-*.img \ + test-imgs/files.img \ + test-imgs/ids-*.img \ + test-imgs/fdinfo-*.img "$TEST_TMP_DIR1"/checkpoint + + ( cd "$TEST_TMP_DIR1" && tar cf "$TEST_TMP_DIR2"/test.tar . ) + checkpointctl edit --tcp-listen-remap 5000:80 "$TEST_TMP_DIR2"/test.tar + [ "$status" -eq 0 ] + [[ ${lines[0]} == *"Successfully remapped port 5000 -> 80"* ]] + + checkpointctl inspect "$TEST_TMP_DIR2"/test.tar --sockets + [[ "$status" -eq 0 ]] + found=0 + for line in "${lines[@]}"; do + if [[ "$line" == *"[TCP (LISTEN)]"* && "$line" == *"0.0.0.0:80"* ]]; then + found=1 + break + fi + done + + [ "$found" -eq 1 ] +} + +@test "Run checkpointctl edit with --tcp-listen-remap on compressed tar" { + cp data/config.dump \ + data/spec.dump "$TEST_TMP_DIR1" + mkdir "$TEST_TMP_DIR1"/checkpoint + cp test-imgs/pstree.img \ + test-imgs/core-*.img \ + test-imgs/files.img \ + test-imgs/ids-*.img \ + test-imgs/fdinfo-*.img "$TEST_TMP_DIR1"/checkpoint + + ( cd "$TEST_TMP_DIR1" && tar czf "$TEST_TMP_DIR2"/test.tar.gz . ) + checkpointctl edit --tcp-listen-remap 5000:80 "$TEST_TMP_DIR2"/test.tar.gz + [ "$status" -eq 0 ] + [[ ${lines[0]} == *"Successfully remapped port 5000 -> 80"* ]] + + checkpointctl inspect "$TEST_TMP_DIR2"/test.tar.gz --sockets + [[ "$status" -eq 0 ]] + found=0 + for line in "${lines[@]}"; do + if [[ "$line" == *"[TCP (LISTEN)]"* && "$line" == *"0.0.0.0:80"* ]]; then + found=1 + break + fi + done + + [ "$found" -eq 1 ] +} + +@test "Run checkpointctl edit with --tcp-listen-remap with non matching port" { + cp data/config.dump \ + data/spec.dump "$TEST_TMP_DIR1" + mkdir "$TEST_TMP_DIR1"/checkpoint + cp test-imgs/files.img "$TEST_TMP_DIR1"/checkpoint + + ( cd "$TEST_TMP_DIR1" && tar czf "$TEST_TMP_DIR2"/test.tar.gz . ) + checkpointctl edit --tcp-listen-remap 8080:80 "$TEST_TMP_DIR2"/test.tar.gz + [ "$status" -eq 1 ] + [[ ${lines[0]} == *"no TCP listen sockets found with source port 8080"* ]] +}