Skip to content

Commit 4be9cc4

Browse files
committed
refactor: enhance module renaming workflows and add validation
- Introduced `RenameProjectModule` to handle module renaming along with related files. - Added validation for zip file size and enhanced `downloadZipReader` with timeout handling. - Improved file path safety checks and error handling workflows. - Updated related test cases to include comprehensive end-to-end integration testing for renaming and build processes.
1 parent 0a23331 commit 4be9cc4

File tree

5 files changed

+164
-38
lines changed

5 files changed

+164
-38
lines changed

cmd/rename.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ func init() {
2525
if *oldMod == "" || *newMod == "" {
2626
return errors.New("--old and --new are required")
2727
}
28-
return renamer.RenameDirModule(*oldMod, *newMod, *target)
28+
if *oldMod == *newMod {
29+
return errors.New("--old and --new must be different")
30+
}
31+
return renamer.RenameProjectModule(*oldMod, *newMod, *target, []string{
32+
"buf.gen.yaml",
33+
"buf.binding.yaml",
34+
}, true)
2935
}
3036
}

internal/create/create.go

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -153,21 +153,13 @@ func initGitRepo(target string) error {
153153

154154
func renameGoModule(oldModName, newModName, target string) error {
155155
log.Printf("rename module: %s -> %s", oldModName, newModName)
156-
err := renamer.RenameDirModule(oldModName, newModName, target)
157-
if err != nil {
158-
return err
159-
}
160-
files := []string{
156+
if err := renamer.RenameProjectModule(oldModName, newModName, target, []string{
161157
"buf.gen.yaml",
162158
"buf.binding.yaml",
159+
}, false); err != nil {
160+
return err
163161
}
164-
for _, file := range files {
165-
e := replaceFileContent(oldModName, newModName, filepath.Join(target, file))
166-
if e != nil {
167-
return e
168-
}
169-
}
170-
err = execCommands(target,
162+
err := execCommands(target,
171163
[]string{"go", "mod", "edit", "-module", newModName},
172164
[]string{"make", "init"},
173165
[]string{"go", "mod", "tidy"},
@@ -198,20 +190,3 @@ func execCommands(dir string, commands ...[]string) error {
198190
}
199191
return nil
200192
}
201-
202-
func replaceFileContent(old, new, filePath string) error {
203-
content, err := os.ReadFile(filePath)
204-
if err != nil {
205-
return err
206-
}
207-
replacer := strings.NewReplacer(old, new)
208-
file, err := os.Create(filePath)
209-
if err != nil {
210-
return err
211-
}
212-
defer func() {
213-
_ = file.Close()
214-
}()
215-
_, err = replacer.WriteString(file, string(content))
216-
return err
217-
}

internal/create/create_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package create
33
import (
44
"flag"
55
"os"
6+
"path/filepath"
67
"testing"
8+
9+
"github.com/go-sphere/sphere-cli/internal/renamer"
710
)
811

912
var createTest = flag.Bool("create_test", false, "run create tests that create files and directories")
@@ -33,3 +36,47 @@ func TestLayout(t *testing.T) {
3336
}
3437
_ = os.RemoveAll("simple")
3538
}
39+
40+
func TestSimpleLayoutCreateAndRenameBuild(t *testing.T) {
41+
if os.Getenv("CI") != "" {
42+
t.Skip("Skipping integration test in CI")
43+
}
44+
if !*createTest {
45+
t.Skip("Skipping integration test, run with -create_test to enable")
46+
}
47+
48+
workspace := t.TempDir()
49+
prevDir, err := os.Getwd()
50+
if err != nil {
51+
t.Fatal(err)
52+
}
53+
if err := os.Chdir(workspace); err != nil {
54+
t.Fatal(err)
55+
}
56+
t.Cleanup(func() {
57+
_ = os.Chdir(prevDir)
58+
})
59+
60+
projectName := "simple-e2e"
61+
oldModule := "github.com/example/simple-e2e"
62+
newModule := "github.com/example/simple-e2e-renamed"
63+
64+
if err := Project(projectName, oldModule, templateLayouts["simple"]); err != nil {
65+
t.Fatal(err)
66+
}
67+
68+
projectDir := filepath.Join(workspace, projectName)
69+
if err := renamer.RenameProjectModule(oldModule, newModule, projectDir, []string{
70+
"buf.gen.yaml",
71+
"buf.binding.yaml",
72+
}, true); err != nil {
73+
t.Fatal(err)
74+
}
75+
76+
if err := execCommands(projectDir,
77+
[]string{"make", "init"},
78+
[]string{"make", "build"},
79+
); err != nil {
80+
t.Fatal(err)
81+
}
82+
}

internal/renamer/renamer.go

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
package renamer
22

33
import (
4+
"errors"
5+
"fmt"
46
"go/ast"
57
"go/parser"
68
"go/printer"
79
"go/token"
10+
"io/fs"
811
"log"
912
"os"
1013
"path/filepath"
1114
"strings"
1215
)
1316

1417
func RenameDirModule(oldModule, newModule string, dir string) error {
15-
err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
16-
if info.IsDir() {
18+
err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
19+
if err != nil {
20+
return err
21+
}
22+
if d.IsDir() {
1723
return nil
1824
}
1925
if strings.HasSuffix(path, ".go") {
@@ -24,6 +30,22 @@ func RenameDirModule(oldModule, newModule string, dir string) error {
2430
if err != nil {
2531
return err
2632
}
33+
return renameGoModuleFile(oldModule, newModule, filepath.Join(dir, "go.mod"))
34+
}
35+
36+
func RenameProjectModule(oldModule, newModule, dir string, relatedFiles []string, ignoreMissingFiles bool) error {
37+
if err := RenameDirModule(oldModule, newModule, dir); err != nil {
38+
return err
39+
}
40+
for _, file := range relatedFiles {
41+
filePath := filepath.Join(dir, file)
42+
if err := replaceFileContent(oldModule, newModule, filePath); err != nil {
43+
if ignoreMissingFiles && errors.Is(err, os.ErrNotExist) {
44+
continue
45+
}
46+
return err
47+
}
48+
}
2749
return nil
2850
}
2951

@@ -60,3 +82,47 @@ func RenameModule(oldModule, newModule string, path string) error {
6082
}
6183
return nil
6284
}
85+
86+
func renameGoModuleFile(oldModule, newModule, modPath string) error {
87+
content, err := os.ReadFile(modPath)
88+
if err != nil {
89+
if errors.Is(err, os.ErrNotExist) {
90+
return fmt.Errorf("go.mod not found in target: %s", modPath)
91+
}
92+
return err
93+
}
94+
95+
lines := strings.Split(string(content), "\n")
96+
found := false
97+
for i, line := range lines {
98+
trimmed := strings.TrimSpace(line)
99+
if !strings.HasPrefix(trimmed, "module ") {
100+
continue
101+
}
102+
fields := strings.Fields(trimmed)
103+
if len(fields) < 2 {
104+
return fmt.Errorf("invalid module directive in %s", modPath)
105+
}
106+
currentModule := strings.Trim(fields[1], `"`)
107+
if oldModule != "" && currentModule != oldModule {
108+
return fmt.Errorf("go.mod module mismatch: expected %q, got %q", oldModule, currentModule)
109+
}
110+
lines[i] = "module " + newModule
111+
found = true
112+
break
113+
}
114+
115+
if !found {
116+
return fmt.Errorf("module directive not found in %s", modPath)
117+
}
118+
return os.WriteFile(modPath, []byte(strings.Join(lines, "\n")), 0o644)
119+
}
120+
121+
func replaceFileContent(old, new, filePath string) error {
122+
content, err := os.ReadFile(filePath)
123+
if err != nil {
124+
return err
125+
}
126+
replaced := strings.ReplaceAll(string(content), old, new)
127+
return os.WriteFile(filePath, []byte(replaced), 0o644)
128+
}

internal/zip/zip.go

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,19 @@ import (
99
"os"
1010
"path/filepath"
1111
"strings"
12+
"time"
13+
)
14+
15+
const (
16+
httpTimeout = 90 * time.Second
17+
maxZipSizeBytes = 100 << 20 // 100 MiB
1218
)
1319

1420
func downloadZipReader(url string) (*zip.Reader, func(), error) {
15-
resp, err := http.Get(url)
21+
client := http.Client{
22+
Timeout: httpTimeout,
23+
}
24+
resp, err := client.Get(url)
1625
if err != nil {
1726
return nil, nil, err
1827
}
@@ -22,27 +31,50 @@ func downloadZipReader(url string) (*zip.Reader, func(), error) {
2231
if resp.StatusCode != http.StatusOK {
2332
return nil, nil, errors.New(resp.Status)
2433
}
34+
if resp.ContentLength > maxZipSizeBytes {
35+
return nil, nil, fmt.Errorf("zip file too large: %d bytes", resp.ContentLength)
36+
}
2537
tempFile, err := os.CreateTemp("", "zip-*")
2638
if err != nil {
2739
return nil, nil, err
2840
}
29-
length, err := io.Copy(tempFile, resp.Body)
41+
cleanup := func() {
42+
_ = tempFile.Close()
43+
_ = os.Remove(tempFile.Name())
44+
}
45+
length, err := io.Copy(tempFile, io.LimitReader(resp.Body, maxZipSizeBytes+1))
3046
if err != nil {
47+
cleanup()
3148
return nil, nil, err
3249
}
50+
if length > maxZipSizeBytes {
51+
cleanup()
52+
return nil, nil, fmt.Errorf("zip file too large: exceeded %d bytes", maxZipSizeBytes)
53+
}
3354
reader, err := zip.NewReader(tempFile, length)
3455
if err != nil {
56+
cleanup()
3557
return nil, nil, err
3658
}
3759
return reader, func() {
38-
_ = tempFile.Close()
39-
_ = os.Remove(tempFile.Name())
60+
cleanup()
4061
}, nil
4162
}
4263

4364
func ensureSafePath(tempDir, fileName string) (string, error) {
44-
filePath := filepath.Join(tempDir, fileName)
45-
if !strings.HasPrefix(filePath, filepath.Clean(tempDir)) {
65+
basePath, err := filepath.Abs(filepath.Clean(tempDir))
66+
if err != nil {
67+
return "", err
68+
}
69+
filePath, err := filepath.Abs(filepath.Join(basePath, fileName))
70+
if err != nil {
71+
return "", err
72+
}
73+
relPath, err := filepath.Rel(basePath, filePath)
74+
if err != nil {
75+
return "", err
76+
}
77+
if relPath == ".." || strings.HasPrefix(relPath, ".."+string(os.PathSeparator)) {
4678
return "", fmt.Errorf("unsafe file path: %s", filePath)
4779
}
4880
return filePath, nil

0 commit comments

Comments
 (0)