Skip to content

Commit 4be80d6

Browse files
committed
contextjson: Add context marshaler/unmarshaler
1 parent 94f0582 commit 4be80d6

File tree

6 files changed

+191
-16
lines changed

6 files changed

+191
-16
lines changed

common/json/context_ext.go

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package json
2+
3+
import (
4+
"context"
5+
6+
"github.com/sagernet/sing/common/json/internal/contextjson"
7+
)
8+
9+
var (
10+
MarshalContext = json.MarshalContext
11+
UnmarshalContext = json.UnmarshalContext
12+
NewEncoderContext = json.NewEncoderContext
13+
NewDecoderContext = json.NewDecoderContext
14+
)
15+
16+
type ContextMarshaler interface {
17+
MarshalJSONContext(ctx context.Context) ([]byte, error)
18+
}
19+
20+
type ContextUnmarshaler interface {
21+
UnmarshalJSONContext(ctx context.Context, content []byte) error
22+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package json
2+
3+
import "context"
4+
5+
type ContextMarshaler interface {
6+
MarshalJSONContext(ctx context.Context) ([]byte, error)
7+
}
8+
9+
type ContextUnmarshaler interface {
10+
UnmarshalJSONContext(ctx context.Context, content []byte) error
11+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package json_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/sagernet/sing/common/json/internal/contextjson"
8+
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
type myStruct struct {
13+
value string
14+
}
15+
16+
func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
17+
return json.Marshal(ctx.Value("key").(string))
18+
}
19+
20+
func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
21+
m.value = ctx.Value("key").(string)
22+
return nil
23+
}
24+
25+
//nolint:staticcheck
26+
func TestMarshalContext(t *testing.T) {
27+
t.Parallel()
28+
ctx := context.WithValue(context.Background(), "key", "value")
29+
var s myStruct
30+
b, err := json.MarshalContext(ctx, &s)
31+
require.NoError(t, err)
32+
require.Equal(t, []byte(`"value"`), b)
33+
}
34+
35+
//nolint:staticcheck
36+
func TestUnmarshalContext(t *testing.T) {
37+
t.Parallel()
38+
ctx := context.WithValue(context.Background(), "key", "value")
39+
var s myStruct
40+
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
41+
require.NoError(t, err)
42+
require.Equal(t, "value", s.value)
43+
}

common/json/internal/contextjson/decode.go

+42-7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
package json
99

1010
import (
11+
"context"
1112
"encoding"
1213
"encoding/base64"
1314
"fmt"
@@ -95,10 +96,15 @@ import (
9596
// Instead, they are replaced by the Unicode replacement
9697
// character U+FFFD.
9798
func Unmarshal(data []byte, v any) error {
99+
return UnmarshalContext(context.Background(), data, v)
100+
}
101+
102+
func UnmarshalContext(ctx context.Context, data []byte, v any) error {
98103
// Check for well-formedness.
99104
// Avoids filling out half a data structure
100105
// before discovering a JSON syntax error.
101106
var d decodeState
107+
d.ctx = ctx
102108
err := checkValid(data, &d.scan)
103109
if err != nil {
104110
return err
@@ -209,6 +215,7 @@ type errorContext struct {
209215

210216
// decodeState represents the state while decoding a JSON value.
211217
type decodeState struct {
218+
ctx context.Context
212219
data []byte
213220
off int // next read offset in data
214221
opcode int // last read result
@@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any {
428435
// If it encounters an Unmarshaler, indirect stops and returns that.
429436
// If decodingNull is true, indirect stops at the first settable pointer so it
430437
// can be set to nil.
431-
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
438+
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) {
432439
// Issue #24153 indicates that it is generally not a guaranteed property
433440
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
434441
// and expect the value to still be settable for values derived from
@@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
482489
}
483490
if v.Type().NumMethod() > 0 && v.CanInterface() {
484491
if u, ok := v.Interface().(Unmarshaler); ok {
485-
return u, nil, reflect.Value{}
492+
return u, nil, nil, reflect.Value{}
493+
}
494+
if cu, ok := v.Interface().(ContextUnmarshaler); ok {
495+
return nil, cu, nil, reflect.Value{}
486496
}
487497
if !decodingNull {
488498
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
489-
return nil, u, reflect.Value{}
499+
return nil, nil, u, reflect.Value{}
490500
}
491501
}
492502
}
@@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
498508
v = v.Elem()
499509
}
500510
}
501-
return nil, nil, v
511+
return nil, nil, nil, v
502512
}
503513

504514
// array consumes an array from d.data[d.off-1:], decoding into v.
505515
// The first byte of the array ('[') has been read already.
506516
func (d *decodeState) array(v reflect.Value) error {
507517
// Check for unmarshaler.
508-
u, ut, pv := indirect(v, false)
518+
u, cu, ut, pv := indirect(v, false)
509519
if u != nil {
510520
start := d.readIndex()
511521
d.skip()
@@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error {
515525
}
516526
return nil
517527
}
528+
if cu != nil {
529+
start := d.readIndex()
530+
d.skip()
531+
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
532+
if err != nil {
533+
d.saveError(err)
534+
}
535+
return nil
536+
}
518537
if ut != nil {
519538
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
520539
d.skip()
@@ -612,7 +631,7 @@ var (
612631
// The first byte ('{') of the object has been read already.
613632
func (d *decodeState) object(v reflect.Value) error {
614633
// Check for unmarshaler.
615-
u, ut, pv := indirect(v, false)
634+
u, cu, ut, pv := indirect(v, false)
616635
if u != nil {
617636
start := d.readIndex()
618637
d.skip()
@@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error {
622641
}
623642
return nil
624643
}
644+
if cu != nil {
645+
start := d.readIndex()
646+
d.skip()
647+
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
648+
if err != nil {
649+
d.saveError(err)
650+
}
651+
return nil
652+
}
625653
if ut != nil {
626654
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
627655
d.skip()
@@ -870,14 +898,21 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
870898
return nil
871899
}
872900
isNull := item[0] == 'n' // null
873-
u, ut, pv := indirect(v, isNull)
901+
u, cu, ut, pv := indirect(v, isNull)
874902
if u != nil {
875903
err := u.UnmarshalJSON(item)
876904
if err != nil {
877905
d.saveError(err)
878906
}
879907
return nil
880908
}
909+
if cu != nil {
910+
err := cu.UnmarshalJSONContext(d.ctx, item)
911+
if err != nil {
912+
d.saveError(err)
913+
}
914+
return nil
915+
}
881916
if ut != nil {
882917
if item[0] != '"' {
883918
if fromQuoted {

common/json/internal/contextjson/encode.go

+60-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ package json
1212

1313
import (
1414
"bytes"
15+
"context"
1516
"encoding"
1617
"encoding/base64"
1718
"fmt"
@@ -156,7 +157,11 @@ import (
156157
// handle them. Passing cyclic structures to Marshal will result in
157158
// an error.
158159
func Marshal(v any) ([]byte, error) {
159-
e := newEncodeState()
160+
return MarshalContext(context.Background(), v)
161+
}
162+
163+
func MarshalContext(ctx context.Context, v any) ([]byte, error) {
164+
e := newEncodeState(ctx)
160165
defer encodeStatePool.Put(e)
161166

162167
err := e.marshal(v, encOpts{escapeHTML: true})
@@ -251,6 +256,7 @@ var hex = "0123456789abcdef"
251256
type encodeState struct {
252257
bytes.Buffer // accumulated output
253258

259+
ctx context.Context
254260
// Keep track of what pointers we've seen in the current recursive call
255261
// path, to avoid cycles that could lead to a stack overflow. Only do
256262
// the relatively expensive map operations if ptrLevel is larger than
@@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000
264270

265271
var encodeStatePool sync.Pool
266272

267-
func newEncodeState() *encodeState {
273+
func newEncodeState(ctx context.Context) *encodeState {
268274
if v := encodeStatePool.Get(); v != nil {
269275
e := v.(*encodeState)
270276
e.Reset()
@@ -274,7 +280,7 @@ func newEncodeState() *encodeState {
274280
e.ptrLevel = 0
275281
return e
276282
}
277-
return &encodeState{ptrSeen: make(map[any]struct{})}
283+
return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})}
278284
}
279285

280286
// jsonError is an error wrapper type for internal use only.
@@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc {
371377
}
372378

373379
var (
374-
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
375-
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
380+
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
381+
contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem()
382+
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
376383
)
377384

378385
// newTypeEncoder constructs an encoderFunc for a type.
@@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
385392
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
386393
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
387394
}
395+
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) {
396+
return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false))
397+
}
388398
if t.Implements(marshalerType) {
389399
return marshalerEncoder
390400
}
401+
if t.Implements(contextMarshalerType) {
402+
return contextMarshalerEncoder
403+
}
391404
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
392405
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
393406
}
@@ -470,6 +483,47 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
470483
}
471484
}
472485

486+
func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
487+
if v.Kind() == reflect.Pointer && v.IsNil() {
488+
e.WriteString("null")
489+
return
490+
}
491+
m, ok := v.Interface().(ContextMarshaler)
492+
if !ok {
493+
e.WriteString("null")
494+
return
495+
}
496+
b, err := m.MarshalJSONContext(e.ctx)
497+
if err == nil {
498+
e.Grow(len(b))
499+
out := availableBuffer(&e.Buffer)
500+
out, err = appendCompact(out, b, opts.escapeHTML)
501+
e.Buffer.Write(out)
502+
}
503+
if err != nil {
504+
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
505+
}
506+
}
507+
508+
func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
509+
va := v.Addr()
510+
if va.IsNil() {
511+
e.WriteString("null")
512+
return
513+
}
514+
m := va.Interface().(ContextMarshaler)
515+
b, err := m.MarshalJSONContext(e.ctx)
516+
if err == nil {
517+
e.Grow(len(b))
518+
out := availableBuffer(&e.Buffer)
519+
out, err = appendCompact(out, b, opts.escapeHTML)
520+
e.Buffer.Write(out)
521+
}
522+
if err != nil {
523+
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
524+
}
525+
}
526+
473527
func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
474528
if v.Kind() == reflect.Pointer && v.IsNil() {
475529
e.WriteString("null")
@@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc {
827881
// Byte slices get special treatment; arrays don't.
828882
if t.Elem().Kind() == reflect.Uint8 {
829883
p := reflect.PointerTo(t.Elem())
830-
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
884+
if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) {
831885
return encodeByteSlice
832886
}
833887
}

0 commit comments

Comments
 (0)