Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions common/fabxhttp/server.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
Copyright IBM Corp. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package fabxhttp

import (
"context"
"crypto/tls"
"net"
"net/http"
"os"
"time"

"github.com/hyperledger/fabric-lib-go/common/flogging"

"github.com/hyperledger/fabric-x-common/common/middleware"
"github.com/hyperledger/fabric-x-common/common/util"
)

// Logger defines the logging interface for the HTTP server.
type Logger interface {
Warn(args ...any)
Warnf(template string, args ...any)
}

// Options contains configuration options for the HTTP server.
type Options struct {
Logger Logger
ListenAddress string
TLS TLS
}

// Server represents an HTTP server with TLS support and middleware capabilities.
type Server struct {
logger Logger
options Options
httpServer *http.Server
mux *http.ServeMux
addr string
}

// NewServer creates a new HTTP server with the provided options.
func NewServer(o Options) *Server {
logger := o.Logger
if logger == nil {
logger = flogging.MustGetLogger("fabhttp")
}

server := &Server{
logger: logger,
options: o,
}

server.initializeServer()

return server
}

// Run starts the server and blocks until a signal is received, then stops the server.
func (s *Server) Run(signals <-chan os.Signal, ready chan<- struct{}) error {
err := s.Start()
if err != nil {
return err
}

close(ready)

<-signals
return s.Stop()
}

// Start begins listening and serving HTTP requests in a goroutine.
func (s *Server) Start() error {
listener, err := s.Listen()
if err != nil {
return err
}
s.addr = listener.Addr().String()

go func() {
if err := s.httpServer.Serve(listener); err != nil && err != http.ErrServerClosed {
s.logger.Warnf("HTTP server stopped with error: %v", err)
}
}()

return nil
}

// Stop gracefully shuts down the server with a timeout.
func (s *Server) Stop() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

return s.httpServer.Shutdown(ctx)
}

func (s *Server) initializeServer() {
s.mux = http.NewServeMux()
s.httpServer = &http.Server{
Addr: s.options.ListenAddress,
Handler: s.mux,
ReadTimeout: 10 * time.Second,
WriteTimeout: 2 * time.Minute,
}
}

// SecureHandlerChain wraps the provided handler with middleware including certificate requirement.
func (*Server) SecureHandlerChain(h http.Handler) http.Handler {
return middleware.NewChain(middleware.RequireCert(), middleware.WithRequestID(util.GenerateUUID)).Handler(h)
}

// InsecureHandlerChain wraps the provided handler with basic middleware without certificate requirement.
func (*Server) InsecureHandlerChain(h http.Handler) http.Handler {
return middleware.NewChain(middleware.WithRequestID(util.GenerateUUID)).Handler(h)
}

// RegisterHandler registers into the ServeMux a handler chain that borrows
// its security properties from the fabhttp.Server. This method is thread
// safe because ServeMux.Handle() is thread safe, and options are immutable.
// This method can be called either before or after Server.Start(). If the
// pattern exists the method panics.
//
//nolint:revive // secure parameter is part of the public API
func (s *Server) RegisterHandler(pattern string, handler http.Handler, secure bool) {
var h http.Handler
if secure {
h = s.SecureHandlerChain(handler)
} else {
h = s.InsecureHandlerChain(handler)
}
s.mux.Handle(pattern, h)
}

// Listen creates a network listener with optional TLS configuration.
func (s *Server) Listen() (net.Listener, error) {
listener, err := net.Listen("tcp", s.options.ListenAddress)
if err != nil {
return nil, err
}
tlsConfig, err := s.options.TLS.Config()
if err != nil {
return nil, err
}
if tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig)
}
return listener, nil
}

// Addr returns the server's listening address.
func (s *Server) Addr() string {
return s.addr
}

// Log logs a warning message with the provided key-value pairs.
func (s *Server) Log(keyvals ...any) error {
s.logger.Warn(keyvals...)
return nil
}
65 changes: 65 additions & 0 deletions common/fabxhttp/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
Copyright IBM Corp. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package fabxhttp

import (
"crypto/tls"
"crypto/x509"
"os"
)

// TLS contains TLS configuration options for the HTTP server.
type TLS struct {
Enabled bool
CertFile string
KeyFile string
ClientCertRequired bool
ClientCACertFiles []string
}

// DefaultTLSCipherSuites contains the strong TLS cipher suites used by default.
var DefaultTLSCipherSuites = []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
}

// Config creates a TLS configuration from the TLS settings.
func (t TLS) Config() (*tls.Config, error) {
var tlsConfig *tls.Config

if t.Enabled {
cert, err := tls.LoadX509KeyPair(t.CertFile, t.KeyFile)
if err != nil {
return nil, err
}
caCertPool := x509.NewCertPool()
for _, caPath := range t.ClientCACertFiles {
caPem, err := os.ReadFile(caPath)
if err != nil {
return nil, err
}
caCertPool.AppendCertsFromPEM(caPem)
}
tlsConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
Certificates: []tls.Certificate{cert},
CipherSuites: DefaultTLSCipherSuites,
ClientCAs: caCertPool,
}
if t.ClientCertRequired {
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
} else {
tlsConfig.ClientAuth = tls.VerifyClientCertIfGiven
}
}

return tlsConfig, nil
}
43 changes: 43 additions & 0 deletions common/middleware/chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
Copyright IBM Corp. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package middleware

import (
"net/http"

"github.com/hyperledger/fabric-lib-go/common/flogging"
)

var logger = flogging.MustGetLogger("middleware")

// Middleware is a function that wraps an http.Handler to provide additional functionality.
type Middleware func(http.Handler) http.Handler

// A Chain is a middleware chain use for http request processing.
type Chain struct {
mw []Middleware
}

// NewChain creates a new Middleware chain. The chain will call the Middleware
// in the order provided.
func NewChain(middlewares ...Middleware) Chain {
return Chain{
mw: append([]Middleware{}, middlewares...),
}
}

// Handler returns an http.Handler for this chain.
func (c Chain) Handler(h http.Handler) http.Handler {
if h == nil {
h = http.DefaultServeMux
}

for i := len(c.mw) - 1; i >= 0; i-- {
h = c.mw[i](h)
}
return h
}
54 changes: 54 additions & 0 deletions common/middleware/request_id.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
Copyright IBM Corp. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package middleware

import (
"context"
"net/http"
)

var requestIDKey = requestIDKeyType{}

type requestIDKeyType struct{}

// RequestID extracts the request ID from the context.
func RequestID(ctx context.Context) string {
if reqID, ok := ctx.Value(requestIDKey).(string); ok {
return reqID
}
return "unknown"
}

// GenerateIDFunc is a function that generates a unique request ID.
type GenerateIDFunc func() string

type requestID struct {
generateID GenerateIDFunc
next http.Handler
}

// WithRequestID returns a middleware that adds a request ID to each request.
func WithRequestID(generator GenerateIDFunc) Middleware {
return func(next http.Handler) http.Handler {
return &requestID{next: next, generateID: generator}
}
}

func (r *requestID) ServeHTTP(w http.ResponseWriter, req *http.Request) {
reqID := req.Header.Get("X-Request-Id")
if reqID == "" {
reqID = r.generateID()
req.Header.Set("X-Request-Id", reqID)
}

ctx := context.WithValue(req.Context(), requestIDKey, reqID)
req = req.WithContext(ctx)

w.Header().Add("X-Request-Id", reqID)

r.next.ServeHTTP(w, req)
}
42 changes: 42 additions & 0 deletions common/middleware/require_cert.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
Copyright IBM Corp. All Rights Reserved.

SPDX-License-Identifier: Apache-2.0
*/

package middleware

import (
"net/http"
)

type requireCert struct {
next http.Handler
}

// RequireCert is used to ensure that a verified TLS client certificate was
// used for authentication.
func RequireCert() Middleware {
return func(next http.Handler) http.Handler {
return &requireCert{next: next}
}
}

func (r *requireCert) ServeHTTP(w http.ResponseWriter, req *http.Request) {
switch {
case req.TLS == nil:
fallthrough
case len(req.TLS.VerifiedChains) == 0:
fallthrough
case len(req.TLS.VerifiedChains[0]) == 0:
logger.Warnw(
"Client request not authorized, client must pass a valid client certificate for this operation",
"URL", req.URL,
"Method", req.Method,
"RemoteAddr", req.RemoteAddr,
)
w.WriteHeader(http.StatusUnauthorized)
default:
r.next.ServeHTTP(w, req)
}
}
Loading
Loading