diff --git a/internal/server/ironhawk/iothread.go b/internal/server/ironhawk/iothread.go index 256a56df2..a39646551 100644 --- a/internal/server/ironhawk/iothread.go +++ b/internal/server/ironhawk/iothread.go @@ -5,9 +5,11 @@ package ironhawk import ( "context" + "github.com/dicedb/dicedb-go" "log/slog" "strings" + "github.com/dicedb/dice/config" "github.com/dicedb/dice/internal/auth" "github.com/dicedb/dice/internal/cmd" "github.com/dicedb/dice/internal/shardmanager" @@ -15,29 +17,41 @@ import ( ) type IOThread struct { - ClientID string - Mode string - IoHandler *IOHandler - Session *auth.Session + ClientID string + Mode string + Session *auth.Session + serverWire *dicedb.ServerWire } func NewIOThread(clientFD int) (*IOThread, error) { - io, err := NewIOHandler(clientFD) + w, err := dicedb.NewServerWire(config.MaxRequestSize, config.KeepAlive, clientFD) if err != nil { - slog.Error("Failed to create new IOHandler for clientFD", slog.Int("client-fd", clientFD), slog.Any("error", err)) - return nil, err + if err.Kind == wire.NotEstablished { + slog.Error("failed to establish connection to client", slog.Int("client-fd", clientFD), slog.Any("error", err)) + + return nil, err.Unwrap() + } else { + slog.Error("unexpected error during client connection establishment, this should be reported to DiceDB maintainers", slog.Int("client-fd", clientFD)) + return nil, err.Unwrap() + } } + return &IOThread{ - IoHandler: io, - Session: auth.NewSession(), + serverWire: w, + Session: auth.NewSession(), }, nil } -func (t *IOThread) StartSync(ctx context.Context, shardManager *shardmanager.ShardManager, watchManager *WatchManager) error { +func (t *IOThread) Start(ctx context.Context, shardManager *shardmanager.ShardManager, watchManager *WatchManager) error { for { - c, err := t.IoHandler.ReadSync() - if err != nil { - return err + var c *wire.Command + { + tmpC, err := t.serverWire.Receive() + if err != nil { + return err.Unwrap() + } + + c = tmpC } _c := &cmd.Cmd{ @@ -83,8 +97,8 @@ func (t *IOThread) StartSync(ctx context.Context, shardManager *shardmanager.Sha watchManager.RegisterThread(t) - if err := t.IoHandler.WriteSync(ctx, res.Rs); err != nil { - return err + if sendErr := t.serverWire.Send(ctx, res.Rs); sendErr != nil { + return sendErr.Unwrap() } // TODO: Streamline this because we need ordering of updates @@ -94,6 +108,7 @@ func (t *IOThread) StartSync(ctx context.Context, shardManager *shardmanager.Sha } func (t *IOThread) Stop() error { + t.serverWire.Close() t.Session.Expire() return nil } diff --git a/internal/server/ironhawk/main.go b/internal/server/ironhawk/main.go index 3b602ecf2..5afad0205 100644 --- a/internal/server/ironhawk/main.go +++ b/internal/server/ironhawk/main.go @@ -155,7 +155,7 @@ func (s *Server) AcceptConnectionRequests(ctx context.Context, wg *sync.WaitGrou func (s *Server) startIOThread(ctx context.Context, wg *sync.WaitGroup, thread *IOThread) { wg.Done() - err := thread.StartSync(ctx, s.shardManager, s.watchManager) + err := thread.Start(ctx, s.shardManager, s.watchManager) if err != nil { if err == io.EOF { s.watchManager.CleanupThreadWatchSubscriptions(thread) diff --git a/internal/server/ironhawk/netconn.go b/internal/server/ironhawk/netconn.go deleted file mode 100644 index 4ed11befc..000000000 --- a/internal/server/ironhawk/netconn.go +++ /dev/null @@ -1,153 +0,0 @@ -// Copyright (c) 2022-present, DiceDB contributors -// All rights reserved. Licensed under the BSD 3-Clause License. See LICENSE file in the project root for full license information. - -package ironhawk - -import ( - "bufio" - "context" - "errors" - "fmt" - "io" - "net" - "os" - "time" - - "github.com/dicedb/dice/config" - "github.com/dicedb/dicedb-go/wire" - "google.golang.org/protobuf/proto" -) - -var ( - ErrRequestTooLarge = errors.New("request too large") - ErrIdleTimeout = errors.New("connection idle timeout") - ErrorClosed = errors.New("connection closed") -) - -// IOHandler handles I/O operations for a network connection -type IOHandler struct { - fd int - file *os.File - conn net.Conn -} - -// NewIOHandler creates a new IOHandler from a file descriptor -func NewIOHandler(clientFD int) (*IOHandler, error) { - file := os.NewFile(uintptr(clientFD), "client-connection") - if file == nil { - return nil, fmt.Errorf("failed to create file from file descriptor") - } - - var conn net.Conn - defer func() { - // Only close the file if we haven't successfully created a net.Conn - if conn == nil { - file.Close() - } - }() - - var err error - conn, err = net.FileConn(file) - if err != nil { - return nil, fmt.Errorf("failed to create net.Conn from file descriptor: %w", err) - } - - if tcpConn, ok := conn.(*net.TCPConn); ok { - if err := tcpConn.SetNoDelay(true); err != nil { - return nil, fmt.Errorf("failed to set TCP_NODELAY: %w", err) - } - if err := tcpConn.SetKeepAlive(true); err != nil { - return nil, fmt.Errorf("failed to set keepalive: %w", err) - } - if err := tcpConn.SetKeepAlivePeriod(time.Duration(config.KeepAlive) * time.Second); err != nil { - return nil, fmt.Errorf("failed to set keepalive period: %w", err) - } - } - - return &IOHandler{ - fd: clientFD, - file: file, - conn: conn, - }, nil -} - -func NewIOHandlerWithConn(conn net.Conn) *IOHandler { - return &IOHandler{ - conn: conn, - } -} - -// ReadRequest reads data from the network connection -func (h *IOHandler) Read(ctx context.Context) ([]byte, error) { - return nil, nil -} - -// ReadRequest reads data from the network connection -func (h *IOHandler) ReadSync() (*wire.Command, error) { - var result []byte - reader := bufio.NewReaderSize(h.conn, config.IoBufferSize) - buf := make([]byte, config.IoBufferSize) - - for { - n, err := reader.Read(buf) - if n > 0 { - if len(result)+n > config.MaxRequestSize { - return nil, fmt.Errorf("request too large") - } - - result = append(result, buf[:n]...) - } - if err != nil { - if err == io.EOF { - break - } - return nil, err - } - - if n < len(buf) { - break - } - } - - if len(result) == 0 { - return nil, io.EOF - } - - c := &wire.Command{} - if err := proto.Unmarshal(result, c); err != nil { - return nil, fmt.Errorf("failed to unmarshal command: %w", err) - } - return c, nil -} - -func (h *IOHandler) Write(ctx context.Context, r interface{}) error { - return nil -} - -func (h *IOHandler) WriteSync(ctx context.Context, r *wire.Result) error { - var b []byte - var err error - - if b, err = proto.Marshal(r); err != nil { - return err - } - - if _, err := h.conn.Write(b); err != nil { - return err - } - - return nil -} - -// Close underlying network connection -func (h *IOHandler) Close() error { - var err error - if h.conn != nil { - err = errors.Join(err, h.conn.Close()) - } - if h.file != nil { - err = errors.Join(err, h.file.Close()) - } - - return err -} diff --git a/internal/server/ironhawk/watch_manager.go b/internal/server/ironhawk/watch_manager.go index 3f5acc770..2ebe5e5a6 100644 --- a/internal/server/ironhawk/watch_manager.go +++ b/internal/server/ironhawk/watch_manager.go @@ -162,7 +162,7 @@ func (w *WatchManager) NotifyWatchers(c *cmd.Cmd, shardManager *shardmanager.Sha continue } - err := thread.IoHandler.WriteSync(context.Background(), r.Rs) + err := thread.serverWire.Send(context.Background(), r.Rs) if err != nil { slog.Error("failed to write response to thread", slog.Any("client_id", thread.ClientID),