Skip to content

Commit bb0db42

Browse files
Fixed can_encode for JSONCodec and a pandas test
1 parent 94f812f commit bb0db42

File tree

3 files changed

+49
-9
lines changed

3 files changed

+49
-9
lines changed

mlserver/codecs/json.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,23 @@ def encode_to_json(v: Any, use_bytes: bool = True) -> Union[str, bytes]:
7878
return enc_v
7979

8080

81+
def _is_primitive(obj):
82+
return isinstance(obj, (int, float, str, bool, type(None)))
83+
84+
85+
def _is_nested_primitives(obj):
86+
if _is_primitive(obj):
87+
return True
88+
elif isinstance(obj, list):
89+
return all(_is_nested_primitives(item) for item in obj)
90+
elif isinstance(obj, dict):
91+
return all(
92+
isinstance(key, str) and _is_nested_primitives(value)
93+
for key, value in obj.items()
94+
)
95+
return False
96+
97+
8198
@register_input_codec
8299
class JSONCodec(InputCodec):
83100
"""
@@ -89,12 +106,23 @@ class JSONCodec(InputCodec):
89106

90107
@classmethod
91108
def can_encode(cls, payload: Any) -> bool:
92-
try:
93-
encode_to_json(payload)
94-
return True
95-
except Exception:
109+
is_json = all(_is_nested_primitives(item) for item in payload)
110+
111+
if not is_json:
96112
return False
97113

114+
all_primitive = all(_is_primitive(item) for item in payload)
115+
116+
# have to do it this way in case payload is not indexable
117+
types = [type(item) for item in payload]
118+
same_type = [types[0] == item for item in types]
119+
120+
# Don't want to json encode a list of primitives of the same type
121+
if all_primitive and same_type:
122+
return False
123+
124+
return True
125+
98126
@classmethod
99127
def encode_output(
100128
cls, name: str, payload: List[Any], use_bytes: bool = True, **kwargs

tests/codecs/test_json.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,17 @@ def test_encode_input(input: Any, expected: bytes):
4949
@pytest.mark.parametrize(
5050
"payload, expected",
5151
[
52-
([1, 2, 3], True),
53-
({"dummy_1": 1, "dummy_2": 2}, True),
54-
(np.array([1, 2, 3]), False),
55-
(set([1, 2, 3]), False),
52+
([[1, 2, 3]], True),
53+
([{"dummy_1": 1, "dummy_2": 2}], True),
54+
(
55+
[
56+
{"dummy_1": 1},
57+
{"dummy_2": [{"dummy_3": 3, "dummy_4": 4}, {"dummy_5": 5}]},
58+
],
59+
True,
60+
),
61+
([np.array([1, 2, 3])], False),
62+
([{1, 2, 3}], False),
5663
],
5764
)
5865
def test_json_codec_can_encode(payload: Any, expected: bool):

tests/codecs/test_pandas.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from mlserver.codecs.pandas import PandasCodec, _to_response_output
88
from mlserver.codecs.string import StringCodec
9+
from mlserver.codecs.json import JSONCodec
910
from mlserver.types import (
1011
InferenceRequest,
1112
InferenceResponse,
@@ -67,7 +68,11 @@ def test_can_encode(payload: Any, expected: bool):
6768
pd.Series(data=[[1, 2, 3], [4, 5, 6]], name="bar"),
6869
True,
6970
ResponseOutput(
70-
name="bar", shape=[2, 1], data=[[1, 2, 3], [4, 5, 6]], datatype="BYTES"
71+
name="bar",
72+
shape=[2, 1],
73+
data=[b"[1,2,3]", b"[4,5,6]"],
74+
datatype="BYTES",
75+
parameters=Parameters(content_type=JSONCodec.ContentType),
7176
),
7277
),
7378
(

0 commit comments

Comments
 (0)