Skip to content

Commit 987e2de

Browse files
author
kaanyalti
committed
feature(4890): updated the function signature of the upgrade command, updated tests, added new tests
1 parent bf299c7 commit 987e2de

File tree

2 files changed

+281
-21
lines changed

2 files changed

+281
-21
lines changed

internal/pkg/agent/cmd/upgrade.go

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ func newUpgradeCommandWithArgs(_ []string, streams *cli.IOStreams) *cobra.Comman
4242
Long: "This command upgrades the currently installed Elastic Agent to the specified version.",
4343
Args: cobra.ExactArgs(1),
4444
Run: func(c *cobra.Command, args []string) {
45+
c.SetContext(context.Background())
4546
if err := upgradeCmd(streams, c, args); err != nil {
4647
fmt.Fprintf(streams.Err, "Error: %v\n%s\n", err, troubleshootMessage())
4748
os.Exit(1)
@@ -61,27 +62,40 @@ func newUpgradeCommandWithArgs(_ []string, streams *cli.IOStreams) *cobra.Comman
6162
return cmd
6263
}
6364

64-
func upgradeCmd(streams *cli.IOStreams, cmd *cobra.Command, args []string) error {
65-
c := client.New()
66-
return upgradeCmdWithClient(streams, cmd, args, c)
65+
type upgradeInput struct {
66+
streams *cli.IOStreams
67+
cmd *cobra.Command
68+
args []string
69+
c client.Client
70+
agentInfo info.Agent
71+
cFunc confirmFunc
6772
}
6873

69-
func shouldUpgrade(ctx context.Context, cmd *cobra.Command) (bool, error) {
70-
agentInfo, err := info.NewAgentInfoWithLog(ctx, "error", false)
74+
type confirmFunc func(string, bool) (bool, error)
75+
76+
func upgradeCmd(streams *cli.IOStreams, cmd *cobra.Command, args []string) error {
77+
c := client.New()
78+
agentInfo, err := info.NewAgentInfoWithLog(cmd.Context(), "error", false)
7179
if err != nil {
72-
return false, fmt.Errorf("failed to retrieve agent info while tring to upgrade the agent: %w", err)
80+
return fmt.Errorf("failed to retrieve agent info while tring to upgrade the agent: %w", err)
81+
}
82+
input := &upgradeInput{
83+
streams,
84+
cmd,
85+
args,
86+
c,
87+
agentInfo,
88+
cli.Confirm,
7389
}
90+
return upgradeCmdWithClient(input)
91+
}
7492

93+
func shouldUpgrade(cmd *cobra.Command, agentInfo info.Agent, cFunc confirmFunc) (bool, error) {
7594
if agentInfo.IsStandalone() {
7695
return true, nil
7796
}
7897

79-
isAdmin, err := utils.HasRoot()
80-
if err != nil {
81-
return false, fmt.Errorf("failed checking root/Administrator rights while trying to upgrade the agent: %w", err)
82-
}
83-
84-
if !isAdmin {
98+
if agentInfo.Unprivileged() {
8599
return false, fmt.Errorf("upgrade command needs to be executed as root for fleet managed agents")
86100
}
87101

@@ -94,7 +108,7 @@ func shouldUpgrade(ctx context.Context, cmd *cobra.Command) (bool, error) {
94108
return false, fmt.Errorf("upgrading fleet managed agents is not supported")
95109
}
96110

97-
cf, err := cli.Confirm("Upgrading fleet managed agents is not supported. Would you still like to proceed?", false)
111+
cf, err := cFunc("Upgrading fleet managed agents is not supported. Would you still like to proceed?", false)
98112
if err != nil {
99113
return false, fmt.Errorf("failed while confirming action: %w", err)
100114
}
@@ -106,18 +120,18 @@ func shouldUpgrade(ctx context.Context, cmd *cobra.Command) (bool, error) {
106120
return true, nil
107121
}
108122

109-
func upgradeCmdWithClient(streams *cli.IOStreams, cmd *cobra.Command, args []string, c client.Client) error {
110-
version := args[0]
123+
func upgradeCmdWithClient(input *upgradeInput) error {
124+
cmd := input.cmd
125+
c := input.c
126+
version := input.args[0]
111127
sourceURI, _ := cmd.Flags().GetString(flagSourceURI)
112128

113-
ctx := context.Background()
114-
115-
su, err := shouldUpgrade(ctx, cmd)
129+
su, err := shouldUpgrade(cmd, input.agentInfo, input.cFunc)
116130
if !su {
117131
return fmt.Errorf("aborting upgrade: %w", err)
118132
}
119133

120-
err = c.Connect(ctx)
134+
err = c.Connect(cmd.Context())
121135
if err != nil {
122136
return errors.New(err, "Failed communicating to running daemon", errors.TypeNetwork, errors.M("socket", control.Address()))
123137
}
@@ -173,6 +187,6 @@ func upgradeCmdWithClient(streams *cli.IOStreams, cmd *cobra.Command, args []str
173187
return errors.New(err, "Failed trigger upgrade of daemon")
174188
}
175189
}
176-
fmt.Fprintf(streams.Out, "Upgrade triggered to version %s, Elastic Agent is currently restarting\n", version)
190+
fmt.Fprintf(input.streams.Out, "Upgrade triggered to version %s, Elastic Agent is currently restarting\n", version)
177191
return nil
178192
}

internal/pkg/agent/cmd/upgrade_test.go

Lines changed: 247 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/elastic/elastic-agent/internal/pkg/cli"
1919
"github.com/elastic/elastic-agent/pkg/control/v2/client"
2020
"github.com/elastic/elastic-agent/pkg/control/v2/cproto"
21+
mockinfo "github.com/elastic/elastic-agent/testing/mocks/internal_/pkg/agent/application/info"
2122
)
2223

2324
func TestUpgradeCmd(t *testing.T) {
@@ -31,6 +32,8 @@ func TestUpgradeCmd(t *testing.T) {
3132

3233
upgradeCh := make(chan struct{})
3334
mock := &mockServer{upgradeStop: upgradeCh}
35+
mockAgentInfo := mockinfo.NewAgent(t)
36+
mockAgentInfo.On("IsStandalone").Return(true)
3437
cproto.RegisterElasticAgentControlServer(s, mock)
3538
go func() {
3639
err := s.Serve(tcpServer)
@@ -43,10 +46,20 @@ func TestUpgradeCmd(t *testing.T) {
4346
args := []string{"--skip-verify", "8.13.0"}
4447
streams := cli.NewIOStreams()
4548
cmd := newUpgradeCommandWithArgs(args, streams)
49+
cmd.SetContext(context.Background())
50+
51+
commandInput := &upgradeInput{
52+
streams,
53+
cmd,
54+
args,
55+
c,
56+
mockAgentInfo,
57+
nil,
58+
}
4659

4760
// the upgrade command will hang until the server shut down
4861
go func() {
49-
err = upgradeCmdWithClient(streams, cmd, args, c)
62+
err = upgradeCmdWithClient(commandInput)
5063
assert.NoError(t, err)
5164
// verify that we actually talked to the server
5265
counter := atomic.LoadInt32(&mock.upgrades)
@@ -68,6 +81,239 @@ func TestUpgradeCmd(t *testing.T) {
6881
// this makes sure all client assertions are done
6982
<-clientCh
7083
})
84+
t.Run("fail if fleet managed and unprivileged", func(t *testing.T) {
85+
// Set up mock TCP server for gRPC connection
86+
tcpServer, err := net.Listen("tcp", "127.0.0.1:")
87+
require.NoError(t, err)
88+
defer tcpServer.Close()
89+
90+
s := grpc.NewServer()
91+
defer s.Stop()
92+
93+
// Define mock server and agent information
94+
upgradeCh := make(chan struct{})
95+
mock := &mockServer{upgradeStop: upgradeCh}
96+
mockAgentInfo := mockinfo.NewAgent(t)
97+
mockAgentInfo.On("IsStandalone").Return(false) // Simulate fleet-managed agent
98+
mockAgentInfo.On("Unprivileged").Return(true) // Simulate unprivileged mode
99+
cproto.RegisterElasticAgentControlServer(s, mock)
100+
101+
go func() {
102+
err := s.Serve(tcpServer)
103+
assert.NoError(t, err)
104+
}()
105+
106+
// Create client and command
107+
c := client.New(client.WithAddress("http://" + tcpServer.Addr().String()))
108+
args := []string{"8.13.0"} // Version argument
109+
streams := cli.NewIOStreams()
110+
cmd := newUpgradeCommandWithArgs(args, streams)
111+
cmd.SetContext(context.Background())
112+
113+
commandInput := &upgradeInput{
114+
streams,
115+
cmd,
116+
args,
117+
c,
118+
mockAgentInfo,
119+
nil,
120+
}
121+
122+
clientCh := make(chan struct{})
123+
124+
// Execute upgrade command and validate shouldUpgrade error
125+
go func() {
126+
err = upgradeCmdWithClient(commandInput)
127+
128+
// Expect an error due to unprivileged fleet-managed mode
129+
assert.Error(t, err)
130+
assert.Contains(t, err.Error(), "upgrade command needs to be executed as root for fleet managed agents")
131+
132+
// Verify counter has not incremented since upgrade should not proceed
133+
counter := atomic.LoadInt32(&mock.upgrades)
134+
assert.Equal(t, int32(0), counter, "server should not have handled any upgrades")
135+
136+
close(clientCh)
137+
}()
138+
139+
<-clientCh // Ensure goroutine completes before ending test
140+
})
141+
142+
t.Run("fail if fleet managed privileged but no force flag", func(t *testing.T) {
143+
// Set up mock TCP server for gRPC connection
144+
tcpServer, err := net.Listen("tcp", "127.0.0.1:")
145+
require.NoError(t, err)
146+
defer tcpServer.Close()
147+
148+
s := grpc.NewServer()
149+
defer s.Stop()
150+
151+
// Define mock server and agent information
152+
mock := &mockServer{}
153+
mockAgentInfo := mockinfo.NewAgent(t)
154+
mockAgentInfo.On("IsStandalone").Return(false) // Simulate fleet-managed agent
155+
mockAgentInfo.On("Unprivileged").Return(false) // Simulate privileged mode
156+
cproto.RegisterElasticAgentControlServer(s, mock)
157+
158+
go func() {
159+
err := s.Serve(tcpServer)
160+
assert.NoError(t, err)
161+
}()
162+
163+
// Create client and command
164+
c := client.New(client.WithAddress("http://" + tcpServer.Addr().String()))
165+
args := []string{"8.13.0"} // Version argument
166+
streams := cli.NewIOStreams()
167+
cmd := newUpgradeCommandWithArgs(args, streams)
168+
cmd.SetContext(context.Background())
169+
170+
commandInput := &upgradeInput{
171+
streams,
172+
cmd,
173+
args,
174+
c,
175+
mockAgentInfo,
176+
nil,
177+
}
178+
179+
clientCh := make(chan struct{})
180+
181+
// Execute upgrade command and validate shouldUpgrade error
182+
go func() {
183+
err = upgradeCmdWithClient(commandInput)
184+
185+
// Expect an error due to unprivileged fleet-managed mode
186+
assert.Error(t, err)
187+
assert.Contains(t, err.Error(), "upgrading fleet managed agents is not supported")
188+
189+
// Verify counter has not incremented since upgrade should not proceed
190+
counter := atomic.LoadInt32(&mock.upgrades)
191+
assert.Equal(t, int32(0), counter, "server should not have handled any upgrades")
192+
193+
close(clientCh)
194+
}()
195+
196+
<-clientCh // Ensure goroutine completes before ending test
197+
})
198+
t.Run("abort upgrade if fleet managed, privileged, --force is set, and user does not confirm", func(t *testing.T) {
199+
// Set up mock TCP server for gRPC connection
200+
tcpServer, err := net.Listen("tcp", "127.0.0.1:")
201+
require.NoError(t, err)
202+
defer tcpServer.Close()
203+
204+
s := grpc.NewServer()
205+
defer s.Stop()
206+
207+
// Define mock server and agent information
208+
mock := &mockServer{}
209+
mockAgentInfo := mockinfo.NewAgent(t)
210+
mockAgentInfo.On("IsStandalone").Return(false) // Simulate fleet-managed agent
211+
mockAgentInfo.On("Unprivileged").Return(false) // Simulate privileged mode
212+
cproto.RegisterElasticAgentControlServer(s, mock)
213+
214+
go func() {
215+
err := s.Serve(tcpServer)
216+
assert.NoError(t, err)
217+
}()
218+
219+
// Create client and command
220+
c := client.New(client.WithAddress("http://" + tcpServer.Addr().String()))
221+
args := []string{"8.13.0"} // Version argument
222+
streams := cli.NewIOStreams()
223+
cmd := newUpgradeCommandWithArgs(args, streams)
224+
cmd.SetContext(context.Background())
225+
cmd.Flags().Set("force", "true")
226+
227+
commandInput := &upgradeInput{
228+
streams,
229+
cmd,
230+
args,
231+
c,
232+
mockAgentInfo,
233+
func(s string, b bool) (bool, error) {
234+
return false, nil
235+
},
236+
}
237+
238+
clientCh := make(chan struct{})
239+
240+
// Execute upgrade command and validate shouldUpgrade error
241+
go func() {
242+
err = upgradeCmdWithClient(commandInput)
243+
244+
// Expect an error because user does not confirm the upgrade
245+
assert.Error(t, err)
246+
assert.Contains(t, err.Error(), "upgrade not confirmed")
247+
248+
// Verify counter has not incremented since upgrade should not proceed
249+
counter := atomic.LoadInt32(&mock.upgrades)
250+
assert.Equal(t, int32(0), counter, "server should not have handled any upgrades")
251+
252+
close(clientCh)
253+
}()
254+
255+
<-clientCh // Ensure goroutine completes before ending test
256+
})
257+
t.Run("proceed with upgrade if fleet managed, privileged, --force is set, and user confirms upgrade", func(t *testing.T) {
258+
// Set up mock TCP server for gRPC connection
259+
tcpServer, err := net.Listen("tcp", "127.0.0.1:")
260+
require.NoError(t, err)
261+
defer tcpServer.Close()
262+
263+
s := grpc.NewServer()
264+
defer s.Stop()
265+
266+
// Define mock server and agent information
267+
upgradeCh := make(chan struct{})
268+
mock := &mockServer{upgradeStop: upgradeCh}
269+
mockAgentInfo := mockinfo.NewAgent(t)
270+
mockAgentInfo.On("IsStandalone").Return(false) // Simulate fleet-managed agent
271+
mockAgentInfo.On("Unprivileged").Return(false) // Simulate privileged mode
272+
cproto.RegisterElasticAgentControlServer(s, mock)
273+
274+
go func() {
275+
err := s.Serve(tcpServer)
276+
assert.NoError(t, err)
277+
}()
278+
279+
// Create client and command
280+
c := client.New(client.WithAddress("http://" + tcpServer.Addr().String()))
281+
args := []string{"8.13.0"} // Version argument
282+
streams := cli.NewIOStreams()
283+
cmd := newUpgradeCommandWithArgs(args, streams)
284+
cmd.SetContext(context.Background())
285+
cmd.Flags().Set("force", "true")
286+
287+
commandInput := &upgradeInput{
288+
streams,
289+
cmd,
290+
args,
291+
c,
292+
mockAgentInfo,
293+
func(s string, b bool) (bool, error) {
294+
return true, nil
295+
},
296+
}
297+
298+
clientCh := make(chan struct{})
299+
300+
// Execute upgrade command and validate shouldUpgrade error
301+
go func() {
302+
err = upgradeCmdWithClient(commandInput)
303+
304+
assert.NoError(t, err)
305+
306+
// Verify counter has not incremented since upgrade should not proceed
307+
counter := atomic.LoadInt32(&mock.upgrades)
308+
assert.Equal(t, int32(1), counter, "server should handle exactly one upgrade")
309+
310+
close(clientCh)
311+
}()
312+
313+
close(upgradeCh)
314+
315+
<-clientCh // Ensure goroutine completes before ending test
316+
})
71317
}
72318

73319
type mockServer struct {

0 commit comments

Comments
 (0)