Skip to content

Add an ability to control response HTTP headers from within json-rpc method handlers #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package zenrpc

import (
"encoding/json"
"io/ioutil"
"io"
"net/http"
"strings"
"time"
Expand All @@ -16,7 +16,7 @@ type Printer interface {

// ServeHTTP process JSON-RPC 2.0 requests via HTTP.
// http://www.simple-is-better.org/json-rpc/transport_http.html
func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// check for CORS GET & POST requests
if s.options.AllowCORS {
w.Header().Set("Access-Control-Allow-Origin", "*")
Expand Down Expand Up @@ -54,14 +54,14 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

// ok, method is POST and content-type is application/json, process body
b, err := ioutil.ReadAll(r.Body)
b, err := io.ReadAll(r.Body)
var data interface{}

if err != nil {
s.printf("read request body failed with err=%v", err)
data = NewResponseError(nil, ParseError, "", nil)
} else {
data = s.process(newRequestContext(r.Context(), r), b)
data = s.process(newRequestResponseContext(r.Context(), r, w), b)
}

// if responses is empty -> all requests are notifications -> exit immediately
Expand All @@ -86,7 +86,7 @@ func (s Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// ServeWS processes JSON-RPC 2.0 requests via Gorilla WebSocket.
// https://github.com/gorilla/websocket/blob/master/examples/echo/
func (s Server) ServeWS(w http.ResponseWriter, r *http.Request) {
func (s *Server) ServeWS(w http.ResponseWriter, r *http.Request) {
c, err := s.options.Upgrader.Upgrade(w, r, nil)
if err != nil {
s.printf("upgrade connection failed with err=%v", err)
Expand All @@ -107,7 +107,7 @@ func (s Server) ServeWS(w http.ResponseWriter, r *http.Request) {
break
}

data, err := s.Do(newRequestContext(r.Context(), r), message)
data, err := s.Do(newRequestResponseContext(r.Context(), r, w), message)
if err != nil {
s.printf("marshal json response failed with err=%v", err)
c.WriteControl(websocket.CloseInternalServerErr, nil, time.Time{})
Expand Down
26 changes: 18 additions & 8 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"unicode"

"github.com/gorilla/websocket"

"github.com/semrush/zenrpc/v2/smd"
)

Expand All @@ -26,6 +27,9 @@ const (
// context key for http.Request object.
requestKey contextKey = "request"

// context key for http.ResponseWriter implementation.
responseWriterKey contextKey = "responseWriter"

// context key for namespace.
namespaceKey contextKey = "namespace"

Expand Down Expand Up @@ -165,7 +169,7 @@ func (s *Server) process(ctx context.Context, message json.RawMessage) interface
}

// processBatch process batch requests with context.
func (s Server) processBatch(ctx context.Context, requests []Request) []Response {
func (s *Server) processBatch(ctx context.Context, requests []Request) []Response {
reqLen := len(requests)

// running requests in batch asynchronously
Expand Down Expand Up @@ -206,7 +210,7 @@ func (s Server) processBatch(ctx context.Context, requests []Request) []Response
}

// processRequest processes a single request in service invoker.
func (s Server) processRequest(ctx context.Context, req Request) Response {
func (s *Server) processRequest(ctx context.Context, req Request) Response {
// checks for json-rpc version and method
if req.Version != Version || req.Method == "" {
return NewResponseError(req.ID, InvalidRequest, "", nil)
Expand Down Expand Up @@ -248,18 +252,18 @@ func (s Server) processRequest(ctx context.Context, req Request) Response {
}

// Do process JSON-RPC 2.0 request, invokes correct method for namespace and returns JSON-RPC 2.0 Response or marshaller error.
func (s Server) Do(ctx context.Context, req []byte) ([]byte, error) {
func (s *Server) Do(ctx context.Context, req []byte) ([]byte, error) {
return json.Marshal(s.process(ctx, req))
}

func (s Server) printf(format string, v ...interface{}) {
func (s *Server) printf(format string, v ...interface{}) {
if s.logger != nil {
s.logger.Printf(format, v...)
}
}

// SMD returns Service Mapping Description object with all registered methods.
func (s Server) SMD() smd.Schema {
func (s *Server) SMD() smd.Schema {
sch := smd.Schema{
Transport: "POST",
Envelope: "JSON-RPC-2.0",
Expand Down Expand Up @@ -346,9 +350,9 @@ func ConvertToObject(keys []string, params json.RawMessage) (json.RawMessage, er
return buf.Bytes(), nil
}

// newRequestContext creates new context with http.Request.
func newRequestContext(ctx context.Context, req *http.Request) context.Context {
return context.WithValue(ctx, requestKey, req)
// newRequestResponseContext creates new context with http.Request and http.ResponseWriter.
func newRequestResponseContext(ctx context.Context, req *http.Request, resp http.ResponseWriter) context.Context {
return context.WithValue(context.WithValue(ctx, responseWriterKey, resp), requestKey, req)
}

// RequestFromContext returns http.Request from context.
Expand All @@ -357,6 +361,12 @@ func RequestFromContext(ctx context.Context) (*http.Request, bool) {
return r, ok
}

// ResponseHeadersFromContext returns headers map to be sent with HTTP response of passed context.
func ResponseHeadersFromContext(ctx context.Context) (http.Header, bool) {
r, ok := ctx.Value(responseWriterKey).(http.ResponseWriter)
return r.Header(), ok
}

// newNamespaceContext creates new context with current method namespace.
func newNamespaceContext(ctx context.Context, namespace string) context.Context {
return context.WithValue(ctx, namespaceKey, namespace)
Expand Down