diff --git a/kohi/base.py b/kohi/base.py index 864da4b..65723c3 100644 --- a/kohi/base.py +++ b/kohi/base.py @@ -57,6 +57,10 @@ def __init__( self._label: str = '' self._mutations: t.List[t.Callable[[t.Any], t.Any]] = [] + self._is_required = True + self._required_error = '{label} is a required field' + self._default_value = None + def __repr__(self): return f'<{self.__class__.__name__} of kohi>' @@ -89,9 +93,21 @@ def _run_validators(self, data: t.Any): if error: self._errors.append(error) - def _validate(self, data: t.Any): - self._run_validators(data) - self._handle_errors() + def _return_or_default(self, data: t.Any): + if not data is None: + return data + + if self._is_required and self._default_value is None: + error = self._required_error.format(label=self._label) + self._errors.insert(0, error) + return + + return copy.deepcopy(self._default_value) + + def _validate(self, data: t.Any, checks=True): + if checks and (data := self._return_or_default(data)) != None: + self._run_validators(data) + self._handle_errors() return len(self.errors) == 0 def validate(self, data: t.Any): @@ -101,11 +117,13 @@ def validate(self, data: t.Any): def parse(self, data: t.Any): """Analyzes the data and returns after passing the validation step""" - cloned = copy.deepcopy(data) + cloned = copy.deepcopy(self._return_or_default(data)) + + if (cloned != data and data == None) or (cloned == data == None and not self._is_required): + return cloned try: - if not self.validate(cloned): - self._handle_errors() + self.throw()._validate(cloned, False) except Exception as e: raise ParseError(str(e)) from e @@ -130,4 +148,17 @@ def throw(self): def label(self, text: str): self._label = text return self + + def default(self, data: t.Any): + self._default_value = data + return self + + def optional(self): + self._is_required = False + return self + + def required(self, message: str='{label} is a required field'): + self._is_required = True + self._required_error = message + return self \ No newline at end of file diff --git a/kohi/dictionary.py b/kohi/dictionary.py index f29f9f1..62c92fc 100644 --- a/kohi/dictionary.py +++ b/kohi/dictionary.py @@ -38,13 +38,6 @@ def validator(data, parent): self.add_validator(f'valide-{key}', validator) - def validate(self, data: t.Any): - """Validate the given data against the schema""" - self.reset() + def _validate(self, data: t.Any, checks=True): self._validate_props() - for validator in self._validators: - error = validator(data, self) # type: ignore - if error: - self._errors.append(error) - self._handle_errors() - return len(self.errors) == 0 \ No newline at end of file + return super()._validate(data, checks) \ No newline at end of file diff --git a/tests/test_base.py b/tests/test_base.py index bad8586..cda18c9 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,17 +1,14 @@ from kohi.base import BaseSchema from kohi.exceptions import ValidationError, ParseError -def assert_supress_error(fn, error=ValidationError, negate=False): - def wrap(*args, **kwargs): - try: - if negate: - assert not fn(*args, **kwargs) - else: - assert fn(*args, **kwargs) - except Exception as e: - assert isinstance(e, error) - - return wrap +def assert_raise(fn, error=ValidationError, negate=False): + try: + if negate: + assert not fn() + else: + assert fn() + except Exception as e: + assert isinstance(e, error) def test_add_validator(): b = BaseSchema(list) @@ -48,11 +45,8 @@ def test_custom_messages(): assert ob1.errors[0] == message.format(label='object_test', types='tuple') def test_raise(): - se = assert_supress_error(BaseSchema(tuple).throw().validate) - se(True) - - se = assert_supress_error(BaseSchema(str).throw().validate) - se(10) + assert_raise(lambda: BaseSchema(tuple).throw().validate(True)) + assert_raise(lambda: BaseSchema(str).throw().validate(10)) def test_parse(): b = BaseSchema(str) @@ -65,3 +59,10 @@ def test_mutations(): b.add_mutation(lambda a: a.replace('l', 'K')) assert b.parse('loki') == 'Kohi' + +def test_required(): + b = BaseSchema(int).label('test') + + assert not b.validate(None) + assert b.optional().parse(None) == None + assert_raise(lambda: b.required().parse(None), ParseError) diff --git a/tests/test_dictionary.py b/tests/test_dictionary.py index eee52f8..e6e919a 100644 --- a/tests/test_dictionary.py +++ b/tests/test_dictionary.py @@ -14,14 +14,9 @@ def test_base(): def test_errors_list(): d = DictSchema().props( name=StringSchema().length(4), - age=NumberSchema().lte(1) + age=NumberSchema().lte(1).default(1) ) assert d.validate({ 'name': 'kohi', 'age': 1 }) assert not d.validate({ 'name': 'Github', 'age': 15 }) - # assert not d.validate({ 'name': 'Gitlab' }) - -# def test_not_one_of(): -# assert EnumSchema().not_one_of([1]).validate('1') -# assert not EnumSchema().not_one_of([1]).validate(1) -# assert not EnumSchema().not_one_of(['1']).validate('1') \ No newline at end of file + assert d.throw().validate({ 'name': 'Gitlab' })