Skip to content

Commit 8bb86b6

Browse files
authored
fix(ipc): ovm crash when client closed connection (#53)
Signed-off-by: Kevin Cui <[email protected]>
1 parent dcd239e commit 8bb86b6

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

pkg/ipc/restful/restful.go

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,23 +162,30 @@ func (s *Restful) mux() *http.ServeMux {
162162
}
163163

164164
outCh := infinity.NewChannel[string]()
165-
errCh := make(chan string)
166-
doneCh := make(chan struct{})
165+
errCh := make(chan string, 1)
166+
doneCh := make(chan struct{}, 1)
167167

168168
go func() {
169-
if err := s.exec(body.Command, outCh, errCh); err != nil {
169+
if err := s.exec(r.Context(), body.Command, outCh, errCh); err != nil {
170170
s.log.Warnf("Failed to execute command: %v", err)
171171
}
172172

173-
_, _ = fmt.Fprintf(w, "event: done\n")
174-
_, _ = fmt.Fprintf(w, "data: done\n\n")
175-
w.(http.Flusher).Flush()
176-
177173
doneCh <- struct{}{}
178174
outCh.Close()
179175
close(errCh)
180176
}()
181177

178+
defer func() {
179+
select {
180+
case <-r.Context().Done():
181+
// pass
182+
default:
183+
_, _ = fmt.Fprintf(w, "event: done\n")
184+
_, _ = fmt.Fprintf(w, "data: done\n\n")
185+
w.(http.Flusher).Flush()
186+
}
187+
}()
188+
182189
for {
183190
select {
184191
case <-doneCh:
@@ -295,7 +302,7 @@ func (s *Restful) powerSaveMode(enable bool) {
295302
s.opt.PowerSaveMode = enable
296303
}
297304

298-
func (s *Restful) exec(command string, outCh *infinity.Channel[string], errCh chan string) error {
305+
func (s *Restful) exec(ctx context.Context, command string, outCh *infinity.Channel[string], errCh chan string) error {
299306
s.log.Info("request /exec")
300307

301308
conf := &ssh.ClientConfig{
@@ -313,6 +320,10 @@ func (s *Restful) exec(command string, outCh *infinity.Channel[string], errCh ch
313320
}
314321
defer conn.Close()
315322

323+
context.AfterFunc(ctx, func() {
324+
_ = conn.Close()
325+
})
326+
316327
session, err := conn.NewSession()
317328
if err != nil {
318329
errCh <- fmt.Sprintf("new ssh session error: %v", err)

0 commit comments

Comments
 (0)