Skip to content

Commit 7370355

Browse files
authored
refactor(cli)!: improve target handle logic (#32)
Optimize the relevant code to make it easier to maintain. BREAKING CHANGE: 1. `version.json` => `versions.json` 2. `dataImg` => `data_img` in cli versions params Signed-off-by: Black-Hole1 <[email protected]>
1 parent 8921092 commit 7370355

File tree

4 files changed

+219
-212
lines changed

4 files changed

+219
-212
lines changed

pkg/cli/copy.go

Lines changed: 0 additions & 128 deletions
This file was deleted.

pkg/cli/setup.go

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ import (
1919

2020
type Context struct {
2121
Name string
22-
VersionPath string
22+
VersionsPath string
2323
LogPath string
2424
SocketPath string
2525
IsCliMode bool
@@ -221,22 +221,20 @@ func (c *Context) target() error {
221221
return err
222222
}
223223

224-
c.VersionPath = path.Join(c.TargetPath, "version.json")
224+
c.VersionsPath = path.Join(c.TargetPath, "versions.json")
225225
c.KernelPath = path.Join(c.TargetPath, filepath.Base(kernelPath))
226226
c.InitrdPath = path.Join(c.TargetPath, filepath.Base(initrdPath))
227227
c.RootfsPath = path.Join(c.TargetPath, filepath.Base(rootfsPath))
228228
c.DiskDataPath = path.Join(c.TargetPath, "data.img")
229229
c.DiskTmpPath = path.Join(c.TargetPath, "tmp.img")
230230

231-
{
232-
v := newVersion(c.TargetPath, c.VersionPath, c.DiskDataPath)
233-
if err := v.parseWithCmd(); err != nil {
234-
return err
235-
}
231+
target, err := newTarget(c.TargetPath, kernelPath, initrdPath, rootfsPath, c.DiskDataPath, c.VersionsPath)
232+
if err != nil {
233+
return err
234+
}
236235

237-
if err := v.copy(); err != nil {
238-
return err
239-
}
236+
if err := target.handle(); err != nil {
237+
return err
240238
}
241239

242240
if _, err := os.Stat(c.DiskTmpPath); err != nil {

pkg/cli/target.go

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
// SPDX-FileCopyrightText: 2024 OOMOL, Inc. <https://www.oomol.com>
2+
// SPDX-License-Identifier: MPL-2.0
3+
4+
package cli
5+
6+
import (
7+
"encoding/json"
8+
"fmt"
9+
"os"
10+
"path"
11+
"path/filepath"
12+
"strings"
13+
14+
"github.com/oomol-lab/ovm/pkg/utils"
15+
"golang.org/x/sync/errgroup"
16+
)
17+
18+
type versionsJSON struct {
19+
Kernel string `json:"kernel"`
20+
Initrd string `json:"initrd"`
21+
Rootfs string `json:"rootfs"`
22+
DataImg string `json:"data_img"`
23+
24+
path string
25+
needUpdateJSON bool
26+
}
27+
28+
func newVersionsJSON(path string) (*versionsJSON, error) {
29+
v := &versionsJSON{
30+
path: path,
31+
}
32+
33+
if err := parseVersions(); err != nil {
34+
return nil, err
35+
}
36+
37+
if err := v.read(); err != nil {
38+
return nil, err
39+
}
40+
41+
return v, nil
42+
}
43+
44+
// read reads the versions file.
45+
// If parsing fails, the file will be deleted.
46+
func (v *versionsJSON) read() error {
47+
data, err := os.ReadFile(v.path)
48+
if err != nil {
49+
return os.RemoveAll(v.path)
50+
}
51+
52+
if err := json.Unmarshal(data, &v); err != nil {
53+
return os.RemoveAll(v.path)
54+
}
55+
56+
return nil
57+
}
58+
59+
func (v *versionsJSON) saveToDisk() error {
60+
if !v.needUpdateJSON {
61+
return nil
62+
}
63+
64+
data, err := json.Marshal(v)
65+
if err != nil {
66+
return err
67+
}
68+
69+
return os.WriteFile(v.path, data, 0644)
70+
}
71+
72+
func (v *versionsJSON) get(key string) string {
73+
switch key {
74+
case "kernel":
75+
return v.Kernel
76+
case "initrd":
77+
return v.Initrd
78+
case "rootfs":
79+
return v.Rootfs
80+
case "data_img":
81+
return v.DataImg
82+
default:
83+
return ""
84+
}
85+
}
86+
87+
func (v *versionsJSON) set(key, val string) {
88+
var vK *string
89+
switch key {
90+
case "kernel":
91+
vK = &v.Kernel
92+
case "initrd":
93+
vK = &v.Initrd
94+
case "rootfs":
95+
vK = &v.Rootfs
96+
case "data_img":
97+
vK = &v.DataImg
98+
}
99+
100+
if *vK != val {
101+
*vK = val
102+
v.needUpdateJSON = true
103+
}
104+
}
105+
106+
type srcPath struct {
107+
key string
108+
p string
109+
}
110+
111+
type targetContext struct {
112+
targetPath string
113+
114+
srcPaths []srcPath
115+
116+
versionsJSON *versionsJSON
117+
}
118+
119+
func newTarget(targetPath, kernelPath, initrdPath, rootfsPath, dataImgPath, versionsPath string) (*targetContext, error) {
120+
versionsJSON, err := newVersionsJSON(versionsPath)
121+
if err != nil {
122+
return nil, err
123+
}
124+
125+
return &targetContext{
126+
targetPath: targetPath,
127+
srcPaths: []srcPath{
128+
{"kernel", kernelPath},
129+
{"initrd", initrdPath},
130+
{"rootfs", rootfsPath},
131+
{"data_img", dataImgPath},
132+
},
133+
134+
versionsJSON: versionsJSON,
135+
}, nil
136+
}
137+
138+
func (t *targetContext) handle() error {
139+
g := errgroup.Group{}
140+
141+
for _, src := range t.srcPaths {
142+
distPath := path.Join(t.targetPath, filepath.Base(src.p))
143+
144+
if exists, _ := utils.PathExists(distPath); !exists {
145+
t.copyOrCreate(src, &g)
146+
continue
147+
}
148+
149+
if v := t.versionsJSON.get(src.key); v != versionsParams[src.key] {
150+
t.copyOrCreate(src, &g)
151+
continue
152+
}
153+
}
154+
155+
if err := g.Wait(); err != nil {
156+
return err
157+
}
158+
159+
return t.versionsJSON.saveToDisk()
160+
}
161+
162+
func (t *targetContext) copyOrCreate(src srcPath, g *errgroup.Group) {
163+
t.versionsJSON.set(src.key, versionsParams[src.key])
164+
distPath := path.Join(t.targetPath, filepath.Base(src.p))
165+
166+
g.Go(func() error {
167+
if src.key == "data_img" {
168+
if err := os.RemoveAll(distPath); err != nil {
169+
return err
170+
}
171+
172+
return utils.CreateSparseFile(distPath, 8*1024*1024*1024*1024)
173+
}
174+
175+
return utils.Copy(src.p, distPath)
176+
})
177+
}
178+
179+
var versionsParams = map[string]string{
180+
"kernel": "",
181+
"initrd": "",
182+
"rootfs": "",
183+
"data_img": "",
184+
}
185+
186+
func parseVersions() error {
187+
s := strings.Split(versions, ",")
188+
189+
for _, val := range s {
190+
item := strings.Split(strings.TrimSpace(val), "=")
191+
if len(item) != 2 {
192+
continue
193+
}
194+
195+
key := strings.TrimSpace(item[0])
196+
197+
if _, ok := versionsParams[key]; !ok {
198+
continue
199+
}
200+
201+
versionsParams[key] = strings.TrimSpace(item[1])
202+
}
203+
204+
for name, v := range versionsParams {
205+
if v == "" {
206+
return fmt.Errorf("need %s in versions", name)
207+
}
208+
}
209+
210+
return nil
211+
}

0 commit comments

Comments
 (0)