diff --git a/.travis.yml b/.travis.yml index 2e87b65..a8947cd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,24 +7,10 @@ branches: python: - "3.6" - - "2.7" - -env: - - MARSHMALLOW_VERSION=2.19.2 - - MARSHMALLOW_VERSION=3.0.0 - -matrix: - exclude: - - python: "2.7" - env: MARSHMALLOW_VERSION=3.0.0 install: - - pip install marshmallow==$MARSHMALLOW_VERSION - pip install -r requirements.txt -r test-requirements.txt -script: coverage run --source=marshmallow_objects setup.py test -after_success: codecov - stages: - pep8 - test @@ -36,6 +22,9 @@ jobs: script: flake8 --show-source install: pip install flake8 after_success: skip + - stage: test + script: coverage run --source=marshmallow_objects setup.py test + after_success: codecov - stage: deploy script: skip install: skip diff --git a/ChangeLog b/ChangeLog index efe1a66..1e91000 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,6 +1,8 @@ CHANGES ======= +* Drop support for marshmallow 2 + 1.0.23 ------ diff --git a/marshmallow_objects/models.py b/marshmallow_objects/models.py index 9873f47..3399a63 100644 --- a/marshmallow_objects/models.py +++ b/marshmallow_objects/models.py @@ -2,14 +2,9 @@ import contextlib import json import pprint -import sys import threading -try: - import configparser - import io -except ImportError: - import ConfigParser as configparser - import StringIO as io +import configparser +import io import marshmallow from marshmallow import fields @@ -18,10 +13,6 @@ except ImportError: pass -# Checking Marshmallow version -MM2 = marshmallow.__version__.startswith('2') -PY2 = int(sys.version_info[0]) == 2 - @marshmallow.post_load def __make_object__(self, data, **kwargs): @@ -51,7 +42,7 @@ def __new__(mcs, name, parents, dct): schema_fields[method_name] = dct[method_name] elif hasattr( - value, '__marshmallow_tags__' if MM2 else + value, '__marshmallow_hook__') or key in ('Meta', 'on_bind_field', 'handle_error'): schema_fields[key] = value @@ -72,6 +63,7 @@ def __new__(mcs, name, parents, dct): def __call__(cls, *args, **kwargs): if kwargs.pop('__post_load__', False): kwargs.pop("many", None) + kwargs.pop("unknown", None) schema = kwargs.pop('__schema__') obj = cls.__new__(cls, *args, **kwargs) obj.__dump_lock__ = threading.RLock() @@ -87,7 +79,14 @@ def __call__(cls, *args, **kwargs): context = kwargs.pop('context', None) partial = kwargs.pop('partial', None) many = kwargs.pop("many", None) - obj = cls.load(kwargs, many=many, context=context, partial=partial) + unknown = kwargs.pop('unknown', None) + obj = cls.load( + kwargs, + many=many, + context=context, + partial=partial, + unknown=unknown + ) return obj @@ -126,8 +125,6 @@ class Model(with_metaclass(ModelMeta)): @classmethod def __get_schema_class__(cls, **kwargs): - if MM2: - kwargs.setdefault('strict', True) return cls.__schema_class__(**kwargs) def __setattr_default__(self, key, value): @@ -179,18 +176,42 @@ def context(self, value): self.__schema__.context = value @classmethod - def load(cls, data, context=None, many=None, partial=None): + def _override_unknown(cls, schema, unknown): + setattr(schema, "_initial_unknown", schema.unknown) + schema.unknown = unknown + for field in schema.fields.values(): + if isinstance(field, fields.Nested): + cls._override_unknown(field.schema, unknown) + + @classmethod + def _restore_unknown(cls, schema): + if hasattr(schema, "_initial_unknown"): + schema.unknown = getattr(schema, "_initial_unknown") + delattr(schema, "_initial_unknown") + for field in schema.fields.values(): + if isinstance(field, fields.Nested): + cls._restore_unknown(field.schema) + + @classmethod + @contextlib.contextmanager + def propagate_unknwown(cls, schema, unknown=None): + if unknown: + cls._override_unknown(schema, unknown) + yield + cls._restore_unknown(schema) + else: + yield + + @classmethod + def load(cls, data, context=None, many=None, partial=None, unknown=None): schema = cls.__get_schema_class__(context=context, partial=partial) - loaded = schema.load(data, many=many) - if MM2: - return loaded[0] + with cls.propagate_unknwown(schema, unknown): + loaded = schema.load(data, many=many) return loaded def dump(self): with self.__dump_mode_on__(): dump = self.__schema__.dump(self) - if MM2: - return dump.data return dump @classmethod @@ -199,16 +220,16 @@ def load_json(cls, context=None, many=None, partial=None, + unknown=None, *args, **kwargs): schema = cls.__get_schema_class__(context=context) loaded = schema.loads(data, many=many, partial=partial, + unknown=unknown, *args, **kwargs) - if MM2: - return loaded[0] return loaded def dump_json(self): @@ -220,10 +241,17 @@ def load_yaml(cls, context=None, many=None, partial=None, + unknown=None, *args, **kwargs): - loaded = yaml.load(data, *args, **kwargs) - return cls.load(loaded, context=context, many=many, partial=partial) + loaded = yaml.load(data, Loader=yaml.FullLoader) + return cls.load( + loaded, + context=context, + many=many, + partial=partial, + unknown=unknown + ) def dump_yaml(self, default_flow_style=False): return yaml.dump(self.dump(), default_flow_style=default_flow_style) @@ -231,11 +259,7 @@ def dump_yaml(self, default_flow_style=False): @classmethod def load_ini(cls, data, context=None, partial=None, **kwargs): parser = configparser.ConfigParser(**kwargs) - if PY2: - fp = io.StringIO(data) - parser.readfp(fp) - else: - parser.read_string(data) + parser.read_string(data) ddata = {s: dict(parser.items(s)) for s in parser.sections()} ddata.update(parser.defaults()) return cls.load(ddata, context=context, partial=partial) @@ -258,8 +282,6 @@ def dump_ini(self, **kwargs): @classmethod def validate(cls, data, context=None, many=None, partial=None): kwargs = {'context': context} - if MM2: - kwargs['strict'] = False schema = cls.__get_schema_class__(**kwargs) return schema.validate(data, many=many, partial=partial) @@ -297,8 +319,6 @@ def dump_many(data, context=None): else: schema = obj.__get_schema_class__(context=context) obj_data = schema.dump(obj) - if MM2: - obj_data = obj_data[0] ret.append(obj_data) elif (isinstance(obj, collections.Sequence) and not isinstance(obj, str)): diff --git a/requirements.txt b/requirements.txt index 85efe39..030f364 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -marshmallow +marshmallow>=3.0.0 diff --git a/tests/test_models.py b/tests/test_models.py index 05b545a..45a62c0 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -28,7 +28,7 @@ class Meta: def on_bind_field(self, field_name, field_obj): pass - def handle_error(self, error, data): + def handle_error(self, error, data, many, partial): pass @@ -289,7 +289,7 @@ def test_load_yaml_partial(self): @unittest.skipIf(skip_yaml, 'PyYaml is not installed') def test_dump_yaml(self): a = A(test_field='foo') - ydata = yaml.load(a.dump_yaml()) + ydata = yaml.load(a.dump_yaml(), Loader=yaml.UnsafeLoader) self.assertEqual(self.data, ydata) def test_dump_ordered(self): @@ -298,6 +298,18 @@ def test_dump_ordered(self): self.assertIsInstance(a, collections.OrderedDict) self.assertIsInstance(b, dict) + def test_load_unknwon(self): + data = dict( + test_field='foo', + unknown_b="B", + a=dict(test_field='bar', unknown_b="B") + ) + with self.assertRaises(marshmallow.ValidationError): + B.load(data) + b = B.load(data, unknown=marshmallow.EXCLUDE) + self.assertEqual(b.test_field, 'foo') + self.assertEqual(b.a.test_field, 'bar') + class TestContext(unittest.TestCase): def setUp(self): @@ -439,7 +451,7 @@ def test_dump_json(self): def test_dump_yaml(self): bb = B.load(self.data, many=True) ydata = marshmallow.dump_many_yaml(bb) - ddata = yaml.load(ydata) + ddata = yaml.load(ydata, Loader=yaml.UnsafeLoader) self.assertEqual(self.data, ddata)