Skip to content

Commit d9bd671

Browse files
committed
Harden session migration helpers
1 parent 080f0b8 commit d9bd671

File tree

4 files changed

+113
-12
lines changed

4 files changed

+113
-12
lines changed

cmd/internal/migrations/v3/ast_helpers.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ import (
99
)
1010

1111
// parseGoFile parses Go source content into an AST. It returns the parsed file
12-
// and token.FileSet or an error if the content cannot be parsed.
13-
func parseGoFile(content string) (*ast.File, *token.FileSet, error) {
12+
// or an error if the content cannot be parsed.
13+
func parseGoFile(content string) (*ast.File, error) {
1414
fset := token.NewFileSet()
1515
file, err := parser.ParseFile(fset, "", content, parser.ParseComments)
1616
if err != nil {
17-
return nil, nil, fmt.Errorf("parse Go file: %w", err)
17+
return nil, fmt.Errorf("parse Go file: %w", err)
1818
}
19-
return file, fset, nil
19+
return file, nil
2020
}
2121

2222
// collectImportAliases finds all import aliases for the given import path within
@@ -35,6 +35,10 @@ func collectImportAliases(file *ast.File, importPath string) map[string]struct{}
3535
}
3636

3737
if imp.Name != nil {
38+
if imp.Name.Name == "_" || imp.Name.Name == "." {
39+
continue
40+
}
41+
3842
aliases[imp.Name.Name] = struct{}{}
3943
continue
4044
}
@@ -53,7 +57,7 @@ func collectAssignedCallIdents(file *ast.File, predicate func(*ast.CallExpr) boo
5357

5458
ast.Inspect(file, func(n ast.Node) bool {
5559
assign, ok := n.(*ast.AssignStmt)
56-
if !ok || len(assign.Lhs) != len(assign.Rhs) {
60+
if !ok {
5761
return true
5862
}
5963

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package v3
2+
3+
import (
4+
"go/ast"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func Test_parseGoFile_InvalidContent(t *testing.T) {
12+
t.Parallel()
13+
14+
_, err := parseGoFile("package main\n func")
15+
assert.Error(t, err)
16+
}
17+
18+
func Test_collectImportAliases(t *testing.T) {
19+
t.Parallel()
20+
21+
tests := map[string]struct { //nolint:govet // fieldalignment warning is not relevant for test data shapes
22+
content string
23+
importPath string
24+
expected map[string]struct{}
25+
}{
26+
"default alias": {
27+
importPath: "github.com/gofiber/fiber/v3/middleware/session",
28+
content: "package main\nimport \"github.com/gofiber/fiber/v3/middleware/session\"\n",
29+
expected: map[string]struct{}{"session": {}},
30+
},
31+
"explicit alias": {
32+
importPath: "github.com/gofiber/fiber/v3/middleware/session",
33+
content: "package main\nimport sess \"github.com/gofiber/fiber/v3/middleware/session\"\n",
34+
expected: map[string]struct{}{"sess": {}},
35+
},
36+
"blank import ignored": {
37+
importPath: "github.com/gofiber/fiber/v3/middleware/session",
38+
content: "package main\nimport _ \"github.com/gofiber/fiber/v3/middleware/session\"\n",
39+
expected: map[string]struct{}{},
40+
},
41+
"dot import ignored": {
42+
importPath: "github.com/gofiber/fiber/v3/middleware/session",
43+
content: "package main\nimport . \"github.com/gofiber/fiber/v3/middleware/session\"\n",
44+
expected: map[string]struct{}{},
45+
},
46+
"unrelated import": {
47+
importPath: "github.com/gofiber/fiber/v3/middleware/session",
48+
content: "package main\nimport \"github.com/example/other\"\n",
49+
expected: map[string]struct{}{},
50+
},
51+
}
52+
53+
for name, tt := range tests {
54+
t.Run(name, func(t *testing.T) {
55+
t.Parallel()
56+
57+
file, err := parseGoFile(tt.content)
58+
require.NoError(t, err)
59+
60+
aliases := collectImportAliases(file, tt.importPath)
61+
assert.Equal(t, tt.expected, aliases)
62+
})
63+
}
64+
}
65+
66+
func Test_collectAssignedCallIdents(t *testing.T) {
67+
t.Parallel()
68+
69+
content := `package main
70+
71+
func target() (int, error) { return 0, nil }
72+
func other() int { return 1 }
73+
74+
func main() {
75+
primary, secondary := target()
76+
single := target()
77+
_, ignored := target()
78+
value, err := other()
79+
field.Name = target()
80+
}
81+
`
82+
83+
file, err := parseGoFile(content)
84+
require.NoError(t, err)
85+
86+
matches := collectAssignedCallIdents(file, func(call *ast.CallExpr) bool {
87+
if ident, ok := call.Fun.(*ast.Ident); ok {
88+
return ident.Name == "target"
89+
}
90+
return false
91+
})
92+
93+
assert.Contains(t, matches, "primary")
94+
assert.Contains(t, matches, "single")
95+
assert.NotContains(t, matches, "value")
96+
assert.NotContains(t, matches, "ignored")
97+
}

cmd/internal/migrations/v3/session_release.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version)
2828
reStoreGet := regexp.MustCompile(`(?m)^(\s*)(\w+),\s*(\w+)\s*:=\s*(\w+)\.(Get(?:ByID)?)\(`)
2929

3030
changed, err := internal.ChangeFileContent(cwd, func(content string) string {
31-
file, _, err := parseGoFile(content)
31+
file, err := parseGoFile(content)
3232
if err != nil {
3333
return content
3434
}
@@ -53,7 +53,7 @@ func MigrateSessionRelease(cmd *cobra.Command, cwd string, _, _ *semver.Version)
5353
return false
5454
}
5555

56-
return sel.Sel.Name == "New" || strings.HasPrefix(sel.Sel.Name, "NewStore")
56+
return sel.Sel.Name == "New"
5757
})
5858

5959
if len(storeVars) == 0 {

cmd/internal/migrations/v3/session_release_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626
)
2727
2828
func handler(c fiber.Ctx) error {
29-
store := session.NewStore()
29+
store := session.New()
3030
sess, err := store.Get(c)
3131
if err != nil {
3232
return err
@@ -66,7 +66,7 @@ import (
6666
)
6767
6868
func handler(c fiber.Ctx) error {
69-
store := session.NewStore()
69+
store := session.New()
7070
sess, err := store.Get(c)
7171
if err != nil {
7272
return err
@@ -110,7 +110,7 @@ import (
110110
)
111111
112112
func backgroundTask(sessionID string) {
113-
store := session.NewStore()
113+
store := session.New()
114114
sess, err := store.GetByID(context.Background(), sessionID)
115115
if err != nil {
116116
return
@@ -150,7 +150,7 @@ import (
150150
)
151151
152152
func handler(c fiber.Ctx) error {
153-
store := session.NewStore()
153+
store := session.New()
154154
sess, err := store.Get(c)
155155
if err != nil {
156156
c.Status(500)
@@ -233,7 +233,7 @@ import (
233233
)
234234
235235
func handler(c fiber.Ctx) error {
236-
store := alias.NewStore()
236+
store := alias.New()
237237
session, err := store.Get(c)
238238
if err != nil {
239239
return err

0 commit comments

Comments
 (0)