# Source code for arcaflow_plugin_sdk.test_schema

import dataclasses
import doctest
import enum
import re
import typing
import unittest
from dataclasses import dataclass
from re import Pattern

from arcaflow_plugin_sdk import schema
from arcaflow_plugin_sdk.schema import (
    BadArgumentException,
    ConstraintException,
    PropertyType,
    SchemaBuildException,
)


class Color(enum.Enum):
    """String-valued enum used as the fixture for ``EnumTest``."""

    GREEN = "green"
    RED = "red"
class EnumTest(unittest.TestCase):
    """Tests for ``schema.StringEnumType`` built from the ``Color`` enum."""

    def test_unserialize(self):
        """Known values and members unserialize; anything else is rejected."""
        t = schema.StringEnumType(Color)

        # Both the raw string value and the enum member itself are accepted.
        self.assertEqual(Color.GREEN, t.unserialize("green"))
        self.assertEqual(Color.RED, t.unserialize("red"))
        self.assertEqual(Color.GREEN, t.unserialize(Color.GREEN))
        self.assertEqual(Color.RED, t.unserialize(Color.RED))

        # A string that is not a value of Color must be rejected.
        with self.assertRaises(
            schema.ConstraintException,
            msg="Invalid enum value didn't fail.",
        ):
            t.unserialize("blue")

        class DifferentColor(enum.Enum):
            BLUE = "blue"

        # A member of an unrelated enum must be rejected as well.
        with self.assertRaises(
            schema.ConstraintException,
            msg="Invalid enum value didn't fail.",
        ):
            t.unserialize(DifferentColor.BLUE)

        class BadEnum(enum.Enum):
            A = "foo"
            B = False

        # Only the constructor call is expected to raise, so the enum is
        # declared outside the context manager.
        with self.assertRaises(schema.BadArgumentException):
            schema.StringEnumType(BadEnum)
class BoolTest(unittest.TestCase):
    """Tests for ``schema.BoolType``."""

    def test_unserialize(self):
        """Check the accepted true/false spellings and two rejected inputs."""
        bool_type = schema.BoolType()

        false_inputs = (
            "false",
            "no",
            "off",
            "disable",
            "disabled",
            "0",
            0,
            False,
        )
        true_inputs = (
            "true",
            "yes",
            "Yes",
            "YES",
            "on",
            "enable",
            "enabled",
            "1",
            1,
            True,
        )

        for raw in false_inputs:
            self.assertEqual(False, bool_type.unserialize(raw))
        for raw in true_inputs:
            self.assertEqual(True, bool_type.unserialize(raw))

        for invalid in (3.14, ""):
            with self.assertRaises(ConstraintException):
                bool_type.unserialize(invalid)

    def test_serialize(self):
        """Booleans serialize to themselves; other values are rejected."""
        bool_type = schema.BoolType()

        for value in (False, True):
            self.assertEqual(value, bool_type.serialize(value))

        for invalid in (3.14, "yes"):
            with self.assertRaises(ConstraintException):
                bool_type.serialize(invalid)

    def test_validate(self):
        """Booleans validate; a float and a string do not."""
        bool_type = schema.BoolType()

        for value in (False, True):
            bool_type.validate(value)

        for invalid in (3.14, "yes"):
            with self.assertRaises(ConstraintException):
                bool_type.validate(invalid)
class StringTest(unittest.TestCase):
    """Tests for ``schema.StringType``."""

    def test_validator_assignment(self):
        """Values assigned to the validator attributes can be read back."""
        string_type = schema.StringType()
        string_type.min_length = 1
        string_type.max_length = 2
        string_type.pattern = re.compile("^[a-z]$")

        self.assertEqual(1, string_type.min_length)
        self.assertEqual(2, string_type.max_length)
        self.assertEqual("^[a-z]$", string_type.pattern.pattern)

    def test_validation(self):
        """An unconstrained string type accepts empty and non-empty text."""
        string_type = schema.StringType()
        for text in ("", "Hello world!"):
            string_type.unserialize(text)

    def test_validation_min_length(self):
        """``min=1`` rejects the empty string and accepts one character."""
        string_type = schema.StringType(min=1)
        with self.assertRaises(schema.ConstraintException):
            string_type.unserialize("")
        string_type.unserialize("A")

    def test_validation_max_length(self):
        """``max=1`` rejects two characters and accepts one."""
        string_type = schema.StringType(max=1)
        with self.assertRaises(schema.ConstraintException):
            string_type.unserialize("ab")
        string_type.unserialize("a")

    def test_validation_pattern(self):
        """A single-letter pattern rejects ``ab1`` and accepts ``a``."""
        single_letter = re.compile("^[a-zA-Z]$")
        string_type = schema.StringType(pattern=single_letter)
        with self.assertRaises(schema.ConstraintException):
            string_type.unserialize("ab1")
        string_type.unserialize("a")

    def test_unserialize(self):
        """Strings pass through; a dataclass instance is rejected."""

        @dataclasses.dataclass
        class InvalidType:
            pass

        string_type = schema.StringType()
        self.assertEqual("asdf", string_type.unserialize("asdf"))
        with self.assertRaises(schema.ConstraintException):
            string_type.unserialize(InvalidType())
class IntTest(unittest.TestCase):
    """Tests for ``schema.IntType``."""

    def test_assignment(self):
        """Values assigned to ``min`` and ``max`` can be read back."""
        int_type = schema.IntType()
        int_type.min = 1
        int_type.max = 2

        self.assertEqual(1, int_type.min)
        self.assertEqual(2, int_type.max)

    # NOTE(review): this method has no ``test`` prefix, so the default
    # unittest loader does not collect it and its assertions never run.
    # Confirm whether it should be renamed or removed.
    def unserialize(self):
        """Integers pass through unchanged; the string ``"1"`` is rejected."""
        int_type = schema.IntType()
        for number in (0, -1, 1):
            self.assertEqual(number, int_type.unserialize(number))
        with self.assertRaises(schema.ConstraintException):
            int_type.unserialize("1")

    def test_validation_min(self):
        """``min=1`` accepts 2 and 1 and rejects 0."""
        int_type = schema.IntType(min=1)
        for accepted in (2, 1):
            int_type.unserialize(accepted)
        with self.assertRaises(schema.ConstraintException):
            int_type.unserialize(0)

    def test_validation_max(self):
        """``max=1`` accepts 0 and 1 and rejects 2."""
        int_type = schema.IntType(max=1)
        for accepted in (0, 1):
            int_type.unserialize(accepted)
        with self.assertRaises(schema.ConstraintException):
            int_type.unserialize(2)

    def test_unserialize(self):
        """An integer passes through; a non-numeric string is rejected."""
        int_type = schema.IntType()
        self.assertEqual(1, int_type.unserialize(1))
        with self.assertRaises(schema.ConstraintException):
            int_type.unserialize("asdf")
class ListTest(unittest.TestCase):
    """Tests for ``schema.ListType``."""

    def test_assignement(self):
        """Values assigned to ``min`` and ``max`` can be read back."""
        list_type = schema.ListType(schema.StringType())
        list_type.min = 1
        list_type.max = 2

        self.assertEqual(1, list_type.min)
        self.assertEqual(2, list_type.max)

    def test_validation(self):
        """A list of strings is accepted; other shapes are rejected."""

        @dataclasses.dataclass
        class BadData:
            pass

        list_type = schema.ListType(schema.StringType())
        list_type.unserialize(["foo"])

        # A bad item, a bare string, and a non-list object, in that order.
        with self.assertRaises(schema.ConstraintException):
            list_type.unserialize([BadData()])
        with self.assertRaises(schema.ConstraintException):
            list_type.unserialize("5")
        with self.assertRaises(schema.ConstraintException):
            list_type.unserialize(BadData())

    def test_validation_elements(self):
        """Item constraints apply: a 3-character item fails ``min=5``."""
        list_type = schema.ListType(schema.StringType(min=5))
        with self.assertRaises(schema.ConstraintException):
            list_type.unserialize(["foo"])

    def test_validation_min(self):
        """A one-item list is rejected when ``min=3``."""
        list_type = schema.ListType(schema.StringType(), min=3)
        with self.assertRaises(schema.ConstraintException):
            list_type.unserialize(["foo"])

    def test_validation_max(self):
        """A one-item list is rejected when ``max=0``."""
        list_type = schema.ListType(schema.StringType(), max=0)
        with self.assertRaises(schema.ConstraintException):
            list_type.unserialize(["foo"])
class MapTest(unittest.TestCase):
    """Tests for ``schema.MapType`` with string keys and string values."""

    def test_assignment(self):
        """Values assigned to ``min`` and ``max`` can be read back."""
        map_type = schema.MapType(schema.StringType(), schema.StringType())
        map_type.min = 1
        map_type.max = 2

        self.assertEqual(1, map_type.min)
        self.assertEqual(2, map_type.max)

    def test_type_validation(self):
        """Non-string values and non-string keys are both rejected."""

        # Frozen so that instances are hashable and usable as dict keys.
        @dataclasses.dataclass(frozen=True)
        class InvalidData:
            a: str

        map_type = schema.MapType(schema.StringType(), schema.StringType())

        map_type.unserialize({})
        map_type.validate({})
        map_type.unserialize({"foo": "bar"})
        map_type.validate({"foo": "bar"})

        # Invalid value.
        with self.assertRaises(schema.ConstraintException):
            map_type.unserialize({"foo": "bar", "baz": InvalidData("bar")})
        with self.assertRaises(schema.ConstraintException):
            map_type.validate({"foo": "bar", "baz": InvalidData("bar")})

        # Invalid key.
        with self.assertRaises(schema.ConstraintException):
            map_type.unserialize({"foo": "bar", InvalidData("baz"): "baz"})
        with self.assertRaises(schema.ConstraintException):
            map_type.validate({"foo": "bar", InvalidData("baz"): "baz"})

    def test_validation_min(self):
        """With ``min=1`` one entry is accepted and the empty map is not."""
        map_type = schema.MapType(
            schema.StringType(),
            schema.StringType(),
            min=1,
        )

        map_type.unserialize({"foo": "bar"})
        map_type.validate({"foo": "bar"})

        with self.assertRaises(schema.ConstraintException):
            map_type.unserialize({})
        with self.assertRaises(schema.ConstraintException):
            map_type.validate({})

    def test_validation_max(self):
        """With ``max=1`` one entry is accepted and two entries are not."""
        map_type = schema.MapType(
            schema.StringType(),
            schema.StringType(),
            max=1,
        )

        map_type.unserialize({"foo": "bar"})
        map_type.validate({"foo": "bar"})

        with self.assertRaises(schema.ConstraintException):
            map_type.unserialize({"foo": "bar", "baz": "Hello world!"})
        with self.assertRaises(schema.ConstraintException):
            map_type.validate({"foo": "bar", "baz": "Hello world!"})
class AnyTest(unittest.TestCase):
    """Tests for ``schema.AnyType``."""

    def test_unserialize(self):
        """Primitives, maps and nested lists round-trip; others are rejected."""
        any_type = schema.AnyType()

        self.assertEqual(1, any_type.unserialize(1))
        self.assertEqual(1.0, any_type.unserialize(1.0))
        self.assertEqual(True, any_type.unserialize(True))
        self.assertEqual("True", any_type.unserialize("True"))
        self.assertEqual(None, any_type.unserialize(None))
        self.assertEqual({"a": "b"}, any_type.unserialize({"a": "b"}))
        any_type.validate({"foo": "bar"})

        # A list mixing int, float, None and str is expected to be rejected.
        with self.assertRaises(schema.ConstraintException):
            any_type.unserialize([1, 1.0, None, "a"])
        self.assertEqual([{0: ["a"]}], any_type.unserialize([{0: ["a"]}]))

        class IsAClass:
            pass

        # A set and an arbitrary class instance are expected to be rejected.
        with self.assertRaises(ConstraintException):
            any_type.unserialize(set())
        with self.assertRaises(ConstraintException):
            any_type.unserialize(IsAClass())

    def test_validate(self):
        """Each sample value validates without raising."""
        any_type = schema.AnyType()

        any_type.validate("Hello world!")
        any_type.validate(1)
        any_type.validate(1.0)
        any_type.validate(True)
        any_type.validate({"message": "Hello world!"})
        any_type.validate(["Hello world!"])
        any_type.validate(None)
@dataclasses.dataclass
class TestClass:
    """Four-field dataclass used as the fixture for ``ObjectTest``."""

    a: str
    b: int
    c: float
    d: bool
class ObjectTest(unittest.TestCase):
    """Tests for ``schema.ObjectType`` built over ``TestClass``."""

    # Shared object schema: every ``TestClass`` field is a required property.
    t: schema.ObjectType[TestClass] = schema.ObjectType(
        TestClass,
        {
            "a": schema.PropertyType(schema.StringType(), required=True),
            "b": schema.PropertyType(schema.IntType(), required=True),
            "c": schema.PropertyType(schema.FloatType(), required=True),
            "d": schema.PropertyType(schema.BoolType(), required=True),
        },
    )

    def test_serialize(self):
        """A full object serializes to a dict; a ``None`` field is rejected."""
        instance = TestClass("foo", 5, 3.14, True)
        serialized = self.t.serialize(instance)
        self.assertEqual(
            {"a": "foo", "b": 5, "c": 3.14, "d": True},
            serialized,
        )

        instance.b = None
        with self.assertRaises(schema.ConstraintException):
            self.t.serialize(instance)

    def test_validate(self):
        """A full object validates; setting ``b`` to ``None`` is rejected."""
        instance = TestClass("a", 5, 3.14, True)
        self.t.validate(instance)

        instance.b = None
        with self.assertRaises(schema.ConstraintException):
            self.t.validate(instance)

    def test_unserialize(self):
        """A full dict builds the object; incomplete or extra keys fail."""
        instance = self.t.unserialize(
            {"a": "foo", "b": 5, "c": 3.14, "d": True}
        )
        self.assertEqual("foo", instance.a)
        self.assertEqual(5, instance.b)
        self.assertEqual(3.14, instance.c)
        self.assertEqual(True, instance.d)

        rejected_payloads = (
            {"a": "foo"},
            {"b": 5},
            {"a": "foo", "b": 5, "c": 3.14, "d": True, "e": complex(3.14)},
        )
        for payload in rejected_payloads:
            with self.assertRaises(schema.ConstraintException):
                self.t.unserialize(payload)

    def test_field_override(self):
        """``field_override`` maps property ``test-data`` onto field ``a``."""

        @dataclasses.dataclass
        class TestData:
            a: str

        object_schema: schema.ObjectType[TestData] = schema.ObjectType(
            TestData,
            {
                "test-data": schema.PropertyType(
                    schema.StringType(),
                    required=True,
                    field_override="a",
                )
            },
        )

        instance = object_schema.unserialize({"test-data": "foo"})
        self.assertEqual(instance.a, "foo")

        round_tripped = object_schema.serialize(instance)
        self.assertEqual("foo", round_tripped["test-data"])

    def test_init_mismatches(self):
        """A custom ``__init__`` with mismatched parameter names is rejected."""

        @dataclasses.dataclass
        class TestData2:
            a: str
            b: str

            # Takes ``c`` where the properties declare ``b``.
            def __init__(self, c: str, a: str):
                self.a = a
                self.b = c

        mismatch_cases = {"name-mismatch": TestData2}
        for case_name, data_cls in mismatch_cases.items():
            with self.subTest(case_name):
                with self.assertRaises(BadArgumentException):
                    schema.ObjectType(
                        data_cls,
                        {
                            "a": schema.PropertyType(schema.StringType()),
                            "b": schema.PropertyType(schema.StringType()),
                        },
                    )

    def test_baseclass_field(self):
        """A field inherited from a parent dataclass is accepted."""

        @dataclasses.dataclass
        class TestParent:
            a: str

        @dataclasses.dataclass
        class TestSubclass(TestParent):
            pass

        # If a is missing from 'TestSubclass', it will fail.
        schema.ObjectType(
            TestSubclass,
            {"a": schema.PropertyType(schema.StringType())},
        )
class OneOfTest(unittest.TestCase):
    """Tests for ``schema.OneOfStringType`` and ``schema.OneOfIntType``."""

    def test_assignment(self):
        """The discriminator field name can be reassigned and read back."""

        @dataclasses.dataclass
        class OneOfData1:
            a: str

        @dataclasses.dataclass
        class OneOfData2:
            b: int

        object_a = schema.ObjectType(
            OneOfData1, {"a": PropertyType(schema.StringType())}
        )
        object_b = schema.ObjectType(
            OneOfData2, {"b": PropertyType(schema.IntType())}
        )
        scope = schema.ScopeType({"a": object_a, "b": object_b}, "a")

        one_of = schema.OneOfStringType(
            {"a": schema.RefType("a", scope), "b": schema.RefType("b", scope)},
            scope,
            "_type",
        )
        one_of.discriminator_field_name = "foo"
        self.assertEqual("foo", one_of.discriminator_field_name)

        # Constructing the integer-keyed variant must not raise.
        schema.OneOfIntType(
            {1: schema.RefType(1, scope), 2: schema.RefType(2, scope)},
            scope,
            "_type",
        )

    def test_unserialize(self):
        """The ``_type`` discriminator selects the dataclass to build."""

        @dataclasses.dataclass
        class OneOfData1:
            a: str

        @dataclasses.dataclass
        class OneOfData2:
            b: int

        object_a = schema.ObjectType(
            OneOfData1, {"a": PropertyType(schema.StringType())}
        )
        object_b = schema.ObjectType(
            OneOfData2, {"b": PropertyType(schema.IntType())}
        )
        scope = schema.ScopeType({"a": object_a, "b": object_b}, "a")

        one_of = schema.OneOfStringType(
            {"a": schema.RefType("a", scope), "b": schema.RefType("b", scope)},
            scope,
            "_type",
        )

        # Incomplete values to unserialize
        with self.assertRaises(ConstraintException):
            one_of.unserialize({"a": "Hello world!"})
        with self.assertRaises(ConstraintException):
            one_of.unserialize({"b": 42})

        # Mismatching key value
        with self.assertRaises(ConstraintException):
            one_of.unserialize({"_type": "a", 1: "Hello world!"})

        # Invalid key value
        with self.assertRaises(ConstraintException):
            one_of.unserialize({"_type": 1, 1: "Hello world!"})

        first: OneOfData1 = one_of.unserialize(
            {"_type": "a", "a": "Hello world!"}
        )
        self.assertIsInstance(first, OneOfData1)
        self.assertEqual(first.a, "Hello world!")

        second: OneOfData2 = one_of.unserialize({"_type": "b", "b": 42})
        self.assertIsInstance(second, OneOfData2)
        self.assertEqual(second.b, 42)

    def test_unserialize_embedded(self):
        """A discriminator that is also a dataclass field is kept on it."""

        @dataclasses.dataclass
        class OneOfData1:
            type: str
            a: str

        @dataclasses.dataclass
        class OneOfData2:
            b: int

        object_a = schema.ObjectType(
            OneOfData1,
            {
                "type": PropertyType(schema.StringType()),
                "a": PropertyType(schema.StringType()),
            },
        )
        object_b = schema.ObjectType(
            OneOfData2, {"b": PropertyType(schema.IntType())}
        )
        scope = schema.ScopeType({"a": object_a, "b": object_b}, "a")

        one_of = schema.OneOfStringType(
            {"a": schema.RefType("a", scope), "b": schema.RefType("b", scope)},
            scope,
            "type",
        )

        first: OneOfData1 = one_of.unserialize(
            {"type": "a", "a": "Hello world!"}
        )
        self.assertIsInstance(first, OneOfData1)
        self.assertEqual(first.type, "a")
        self.assertEqual(first.a, "Hello world!")

        second: OneOfData2 = one_of.unserialize({"type": "b", "b": 42})
        self.assertIsInstance(second, OneOfData2)
        self.assertEqual(second.b, 42)

    def test_validation(self):
        """A missing or mismatching embedded discriminator is rejected."""

        @dataclasses.dataclass
        class OneOfData1:
            type: str
            a: str

        @dataclasses.dataclass
        class OneOfData2:
            b: int

        object_a = schema.ObjectType(
            OneOfData1,
            {
                "type": PropertyType(schema.StringType()),
                "a": PropertyType(schema.StringType()),
            },
        )
        object_b = schema.ObjectType(
            OneOfData2, {"b": PropertyType(schema.IntType())}
        )
        scope = schema.ScopeType({"a": object_a, "b": object_b}, "a")

        one_of = schema.OneOfStringType[OneOfData1](
            {"a": schema.RefType("a", scope), "b": schema.RefType("b", scope)},
            scope,
            "type",
        )

        with self.assertRaises(ConstraintException):
            # noinspection PyTypeChecker
            one_of.validate(OneOfData1(None, "Hello world!"))

        with self.assertRaises(ConstraintException):
            one_of.validate(OneOfData1("b", "Hello world!"))

        one_of.validate(OneOfData1("a", "Hello world!"))

    def test_serialize(self):
        """Serialized output carries the discriminator for both variants."""

        @dataclasses.dataclass
        class OneOfData1:
            type: str
            a: str

        @dataclasses.dataclass
        class OneOfData2:
            b: int

        object_a = schema.ObjectType(
            OneOfData1,
            {
                "type": PropertyType(schema.StringType()),
                "a": PropertyType(schema.StringType()),
            },
        )
        object_b = schema.ObjectType(
            OneOfData2, {"b": PropertyType(schema.IntType())}
        )
        scope = schema.ScopeType({"a": object_a, "b": object_b}, "a")

        one_of = schema.OneOfStringType(
            {"a": schema.RefType("a", scope), "b": schema.RefType("b", scope)},
            scope,
            "type",
        )

        self.assertEqual(
            one_of.serialize(OneOfData1("a", "Hello world!")),
            {"type": "a", "a": "Hello world!"},
        )
        self.assertEqual(
            one_of.serialize(OneOfData2(42)),
            {"type": "b", "b": 42},
        )

    def test_object(self):
        """Object types can be given directly instead of references."""

        @dataclasses.dataclass
        class OneOfData1:
            type: str
            a: str

        @dataclasses.dataclass
        class OneOfData2:
            b: int

        scope = schema.ScopeType({}, "")
        variants = {
            "a": schema.ObjectType(
                OneOfData1,
                {
                    "type": PropertyType(schema.StringType()),
                    "a": PropertyType(schema.StringType()),
                },
            ),
            "b": schema.ObjectType(
                OneOfData2, {"b": PropertyType(schema.IntType())}
            ),
        }
        one_of = schema.OneOfStringType(variants, scope, "type")

        result = one_of.unserialize({"type": "b", "b": 42})
        self.assertIsInstance(result, OneOfData2)
class SerializationTest(unittest.TestCase):
    """Round-trip and conditional-requirement tests on built schemas."""

    def test_serialization_cycle(self):
        """Run ``schema.test_object_serialization`` on two dataclasses."""

        @dataclasses.dataclass
        class TestData1:
            A: str
            B: int
            C: typing.Dict[str, int]
            D: typing.List[str]
            H: float
            E: typing.Optional[str] = None
            F: typing.Annotated[typing.Optional[str], schema.min(3)] = None
            G: typing.Optional[str] = dataclasses.field(
                default="",
                metadata={"id": "test-field", "name": "G"},
            )
            I: typing.Any = None

        first_sample = TestData1(A="Hello world!", B=5, C={}, D=[], H=3.14)
        schema.test_object_serialization(first_sample, self.fail)

        @dataclasses.dataclass
        class KillPodConfig:
            namespace_pattern: re.Pattern

            name_pattern: typing.Annotated[
                typing.Optional[re.Pattern],
                schema.required_if_not("label_selector"),
            ] = None

            kill: typing.Annotated[int, schema.min(1)] = dataclasses.field(
                default=1,
                metadata={
                    "name": "Number of pods to kill",
                    "description": "How many pods should we attempt to kill?",
                },
            )

            label_selector: typing.Annotated[
                typing.Optional[str],
                schema.min(1),
                schema.required_if_not("name_pattern"),
            ] = None

            kubeconfig_path: typing.Optional[str] = None

        second_sample = KillPodConfig(
            namespace_pattern=re.compile(".*"),
            name_pattern=re.compile(".*"),
        )
        schema.test_object_serialization(second_sample, self.fail)

    def test_required_if(self):
        """``A`` is annotated ``required_if("B")``: setting only ``B`` fails."""

        @dataclasses.dataclass
        class TestData1:
            A: typing.Annotated[
                typing.Optional[str], schema.required_if("B")
            ] = None
            B: typing.Optional[int] = None

        built = schema.build_object_schema(TestData1)

        result = built.unserialize({})
        self.assertIsNone(result.A)
        self.assertIsNone(result.B)

        result = built.unserialize({"A": None, "B": None})
        self.assertIsNone(result.A)
        self.assertIsNone(result.B)

        result = built.unserialize({"A": "Foo"})
        self.assertEqual(result.A, "Foo")
        self.assertIsNone(result.B)

        with self.assertRaises(schema.ConstraintException):
            built.unserialize({"B": "Foo"})
        with self.assertRaises(schema.ConstraintException):
            built.validate(TestData1(B="Foo"))
        with self.assertRaises(schema.ConstraintException):
            built.serialize(TestData1(B="Foo"))

    def test_required_if_not(self):
        """``required_if_not`` fields must be set when their targets are not."""

        @dataclasses.dataclass
        class TestData1:
            A: typing.Optional[str] = None
            B: typing.Annotated[
                typing.Optional[str], schema.required_if_not("A")
            ] = None

        built = schema.build_object_schema(TestData1)

        with self.assertRaises(schema.ConstraintException):
            built.unserialize({})
        with self.assertRaises(schema.ConstraintException):
            built.unserialize({"A": None, "B": None})

        result = built.unserialize({"A": "Foo"})
        self.assertEqual(result.A, "Foo")
        self.assertIsNone(result.B)

        result = built.unserialize({"B": "Foo"})
        self.assertEqual(result.B, "Foo")
        self.assertIsNone(result.A)

        built.validate(TestData1(B="Foo"))
        built.serialize(TestData1(B="Foo"))

        @dataclasses.dataclass
        class TestData2:
            A: typing.Optional[str] = None
            B: typing.Optional[str] = None
            C: typing.Annotated[
                typing.Optional[str],
                schema.required_if_not("A"),
                schema.required_if_not("B"),
            ] = None

        built = schema.build_object_schema(TestData2)

        with self.assertRaises(schema.ConstraintException):
            built.unserialize({"A": None, "B": None, "C": None})

        result = built.unserialize({"C": "Foo"})
        self.assertIsNone(result.A)
        self.assertIsNone(result.B)
        self.assertEqual(result.C, "Foo")

        only_c = TestData2(C="Foo")
        built.validate(only_c)
        built.serialize(only_c)

    def test_int_optional(self):
        """An optional int field is ``None`` when absent from the input."""

        @dataclasses.dataclass
        class TestData1:
            A: typing.Optional[int] = None

        built = schema.build_object_schema(TestData1)
        result = built.unserialize({})
        self.assertIsNone(result.A)

    def test_float_optional(self):
        """An optional float field is ``None`` when absent from the input."""

        @dataclasses.dataclass
        class TestData1:
            A: typing.Optional[float] = None

        built = schema.build_object_schema(TestData1)
        result = built.unserialize({})
        self.assertIsNone(result.A)

    def test_build_object_schema_wrapping(self):
        """The error for a missing field mentions the dataclass name."""

        @dataclasses.dataclass
        class TestData1:
            A: int

        built = schema.build_object_schema(TestData1)
        with self.assertRaises(ConstraintException) as caught:
            built.unserialize({})
        self.assertIn("TestData1", str(caught.exception))

    def test_default_value(self):
        """A field with an enum default round-trips through the schema."""

        class TestEnum(enum.Enum):
            A = "a"

        @dataclasses.dataclass
        class A:
            a: TestEnum = TestEnum.A

        built = schema.build_object_schema(A)
        instance = built.unserialize({"a": "a"})
        built.validate(instance)
        round_tripped = built.serialize(instance)
        self.assertEqual(round_tripped, {"a": "a"})
class SchemaBuilderTest(unittest.TestCase):
    """Tests for type resolution by ``schema._SchemaBuilder``."""

    def test_any(self):
        """``typing.Any`` resolves to ``AnyType``."""
        scope = schema.ScopeType({}, "a")
        resolved = schema._SchemaBuilder.resolve(typing.Any, scope)
        self.assertIsInstance(resolved, schema.AnyType)

    def test_non_dataclass(self):
        """``complex`` is rejected with an explanatory message."""
        scope = schema.ScopeType({}, "a")
        with self.assertRaises(SchemaBuildException) as caught:
            schema._SchemaBuilder.resolve(complex, scope)
        self.assertIn(
            "complex numbers are not supported",
            caught.exception.msg,
        )

    def test_regexp(self):
        """``re.Pattern`` resolves to ``PatternType``."""
        scope = schema.ScopeType({}, "a")
        resolved = schema._SchemaBuilder.resolve(Pattern, scope)
        self.assertIsInstance(resolved, schema.PatternType)

    def test_string(self):
        """Both ``str`` and a string value resolve to ``StringType``."""
        scope = schema.ScopeType({}, "a")
        sample: str = "foo"
        for candidate in (type(sample), sample):
            resolved = schema._SchemaBuilder.resolve(candidate, scope)
            self.assertIsInstance(resolved, schema.StringType)

    def test_int(self):
        """Both ``int`` and an int value resolve to ``IntType``."""
        scope = schema.ScopeType({}, "a")
        sample: int = 5
        for candidate in (type(sample), sample):
            resolved = schema._SchemaBuilder.resolve(candidate, scope)
            self.assertIsInstance(resolved, schema.IntType)

    def test_float(self):
        """Both ``float`` and a float value resolve to ``FloatType``."""
        scope = schema.ScopeType({}, "a")
        sample: float = 3.14
        for candidate in (type(sample), sample):
            resolved = schema._SchemaBuilder.resolve(candidate, scope)
            self.assertIsInstance(resolved, schema.FloatType)

    def test_string_enum(self):
        """An enum with string values resolves to ``StringEnumType``."""
        scope = schema.ScopeType({}, "a")

        class TestEnum(enum.Enum):
            A = "a"
            B = "b"

        resolved = schema._SchemaBuilder.resolve(TestEnum, scope)
        self.assertIsInstance(resolved, schema.StringEnumType)

    def test_int_enum(self):
        """An enum with integer values resolves to ``IntEnumType``."""
        scope = schema.ScopeType({}, "a")

        class TestEnum(enum.Enum):
            A = 1
            B = 2

        resolved = schema._SchemaBuilder.resolve(TestEnum, scope)
        self.assertIsInstance(resolved, schema.IntEnumType)

    def test_list(self):
        """``List[str]`` resolves to a list of strings; bare ``list`` fails."""
        scope = schema.ScopeType({}, "a")

        resolved = schema._SchemaBuilder.resolve(typing.List[str], scope)
        self.assertIsInstance(resolved, schema.ListType)
        self.assertIsInstance(resolved.items, schema.StringType)

        untyped: list = []
        with self.assertRaises(SchemaBuildException):
            schema._SchemaBuilder.resolve(type(untyped), scope)

    def test_map(self):
        """Parameterized dicts resolve to ``MapType``; bare ``dict`` fails."""
        scope = schema.ScopeType({}, "a")

        resolved = schema._SchemaBuilder.resolve(typing.Dict[str, str], scope)
        self.assertIsInstance(resolved, schema.MapType)
        self.assertIsInstance(resolved.keys, schema.StringType)
        self.assertIsInstance(resolved.values, schema.StringType)

        untyped: dict = {}
        with self.assertRaises(SchemaBuildException):
            schema._SchemaBuilder.resolve(type(untyped), scope)

        # The builtin generic alias is resolved the same way.
        resolved = schema._SchemaBuilder.resolve(dict[str, str], scope)
        self.assertIsInstance(resolved, schema.MapType)
        self.assertIsInstance(resolved.keys, schema.StringType)
        self.assertIsInstance(resolved.values, schema.StringType)

    def test_class(self):
        """Dataclasses resolve to a reference plus an object in the scope."""
        scope = schema.ScopeType({}, "TestData")

        # A plain (non-dataclass) class is rejected.
        class TestData:
            a: str
            b: int
            c: float
            d: bool

        with self.assertRaises(SchemaBuildException):
            schema._SchemaBuilder.resolve(TestData, scope)

        # Fields without defaults are checked to be required properties.
        @dataclasses.dataclass
        class TestData:
            a: str
            b: int
            c: float
            d: bool

        scope = schema.ScopeType({}, "TestData")
        resolved = schema._SchemaBuilder.resolve(TestData, scope)
        self.assertIsInstance(resolved, schema.RefType)
        self.assertEqual(1, len(scope.objects))
        object_schema = scope.objects["TestData"]
        self.assertIsInstance(object_schema, schema.ObjectType)

        required_expectations = (
            ("a", schema.StringType),
            ("b", schema.IntType),
            ("c", schema.FloatType),
            ("d", schema.BoolType),
        )
        for prop_id, expected_type in required_expectations:
            prop = object_schema.properties[prop_id]
            self.assertIsNone(prop.display.name)
            self.assertTrue(prop.required)
            self.assertIsInstance(prop.type, expected_type)

        # Fields with defaults are checked to be optional; the ``name`` and
        # ``description`` metadata are checked against the display values.
        @dataclasses.dataclass
        class TestData:
            a: str = "foo"
            b: int = 5
            c: str = dataclasses.field(
                default="bar",
                metadata={"name": "C", "description": "A string"},
            )
            d: bool = True

        scope = schema.ScopeType({}, "TestData")
        resolved = schema._SchemaBuilder.resolve(TestData, scope)
        self.assertIsInstance(resolved, schema.RefType)
        self.assertEqual(1, len(scope.objects))
        object_schema = scope.objects["TestData"]
        self.assertIsInstance(object_schema, schema.ObjectType)

        for prop_id, expected_type in (
            ("a", schema.StringType),
            ("b", schema.IntType),
        ):
            prop = object_schema.properties[prop_id]
            self.assertIsNone(prop.display.name)
            self.assertFalse(prop.required)
            self.assertIsInstance(prop.type, expected_type)

        prop_c = object_schema.properties["c"]
        self.assertEqual("C", prop_c.display.name)
        self.assertEqual("A string", prop_c.display.description)
        self.assertFalse(prop_c.required)
        self.assertIsInstance(prop_c.type, schema.StringType)

        prop_d = object_schema.properties["d"]
        self.assertIsNone(prop_d.display.name)
        self.assertFalse(prop_d.required)
        self.assertIsInstance(prop_d.type, schema.BoolType)

    def test_union(self):
        """A ``Union`` of dataclasses becomes a string-keyed one-of."""

        @dataclasses.dataclass
        class A:
            a: str

        @dataclasses.dataclass
        class B:
            b: str

        @dataclasses.dataclass
        class TestData:
            a: typing.Union[A, B]

        scope = schema.build_object_schema(TestData)
        self.assertEqual("TestData", scope.root)
        for object_id in ("TestData", "A", "B"):
            self.assertIsInstance(scope.objects[object_id], schema.ObjectType)

        union_property = scope.objects["TestData"].properties["a"]
        self.assertIsInstance(union_property.type, schema.OneOfStringType)

        one_of_type: schema.OneOfStringType = (
            scope.objects["TestData"].properties["a"].type
        )
        self.assertEqual(one_of_type.discriminator_field_name, "_type")
        for key in ("A", "B"):
            self.assertIsInstance(one_of_type.types[key], schema.RefType)
            self.assertEqual(one_of_type.types[key].id, key)

    def test_union_custom_discriminator(self):
        """Annotated discriminator values produce an int-keyed one-of."""

        @dataclasses.dataclass
        class A:
            discriminator: int
            a: str

        @dataclasses.dataclass
        class B:
            discriminator: int
            b: str

        @dataclasses.dataclass
        class TestData:
            a: typing.Annotated[
                typing.Union[
                    typing.Annotated[A, schema.discriminator_value(1)],
                    typing.Annotated[B, schema.discriminator_value(2)],
                ],
                schema.discriminator("discriminator"),
            ]

        scope = schema.build_object_schema(TestData)
        self.assertEqual("TestData", scope.root)
        for object_id in ("TestData", "A", "B"):
            self.assertIsInstance(scope.objects[object_id], schema.ObjectType)

        union_property = scope.objects["TestData"].properties["a"]
        self.assertIsInstance(union_property.type, schema.OneOfIntType)

        one_of_type: schema.OneOfIntType = (
            scope.objects["TestData"].properties["a"].type
        )
        self.assertEqual(
            one_of_type.discriminator_field_name,
            "discriminator",
        )
        for key, object_id in ((1, "A"), (2, "B")):
            self.assertIsInstance(one_of_type.types[key], schema.RefType)
            self.assertEqual(one_of_type.types[key].id, object_id)

    def test_optional(self):
        """``Optional[str]`` becomes a non-required string property."""

        @dataclasses.dataclass
        class TestData:
            a: typing.Optional[str] = None

        scope = schema.build_object_schema(TestData)
        self.assertEqual("TestData", scope.root)
        self.assertIsInstance(scope.objects["TestData"], schema.ObjectType)

        object_schema = scope.objects["TestData"]
        self.assertFalse(object_schema.properties["a"].required)
        self.assertIsInstance(
            object_schema.properties["a"].type,
            schema.StringType,
        )

    def test_annotated(self):
        """Validator annotations apply; a bare string annotation fails."""
        scope = schema.ScopeType({}, "TestData")
        resolved = schema._SchemaBuilder.resolve(
            typing.Annotated[str, schema.min(3)],
            scope,
        )
        self.assertIsInstance(resolved, schema.StringType)
        self.assertEqual(3, resolved.min)

        @dataclasses.dataclass
        class TestData:
            a: typing.Annotated[typing.Optional[str], schema.min(3)] = None

        scope = schema.ScopeType({}, "TestData")
        schema._SchemaBuilder.resolve(TestData, scope)
        object_schema = scope.objects["TestData"]
        prop_a = object_schema.properties["a"]
        self.assertIsInstance(prop_a.type, schema.StringType)
        self.assertFalse(prop_a.required)
        self.assertEqual(3, prop_a.type.min)

        with self.assertRaises(SchemaBuildException):

            @dataclasses.dataclass
            class TestData:
                a: typing.Annotated[typing.Optional[str], "foo"] = None

            scope = schema.ScopeType({}, "TestData")
            schema._SchemaBuilder.resolve(TestData, scope)

    def test_annotated_required_if(self):
        """``required_if`` is recorded on the property it annotates."""

        @dataclasses.dataclass
        class TestData2:
            a: typing.Annotated[
                typing.Optional[str], schema.required_if("b")
            ] = None
            b: typing.Optional[str] = None

        scope = schema.ScopeType({}, "TestData2")
        schema._SchemaBuilder.resolve(TestData2, scope)

        object_schema = scope.objects["TestData2"]
        prop_a = object_schema.properties["a"]
        prop_b = object_schema.properties["b"]
        self.assertFalse(prop_a.required)
        self.assertFalse(prop_b.required)
        self.assertEqual(["b"], prop_a.required_if)

    def test_different_id(self):
        """The ``id`` metadata renames the property and sets the override."""

        @dataclasses.dataclass
        class TestData:
            a: str = dataclasses.field(metadata={"id": "test-field"})

        scope = schema.ScopeType({}, "TestData")
        schema._SchemaBuilder.resolve(TestData, scope)

        object_schema = scope.objects["TestData"]
        renamed = object_schema.properties["test-field"]
        self.assertEqual(renamed.field_override, "a")

    def test_unclear_error_message(self):
        """Serializing a non-dict value in a dict field is rejected."""

        @dataclasses.dataclass
        class TestData:
            a: typing.Dict[str, str]

        scope = schema.ScopeType({}, "TestData")
        schema._SchemaBuilder.resolve(TestData, scope)
        object_schema = scope.objects["TestData"]

        with self.assertRaises(ConstraintException):
            # noinspection PyTypeChecker
            object_schema.serialize(TestData(type(dict[str, str])))
class JSONSchemaTest(unittest.TestCase):
    """Tests for the JSON schema fragments produced by the schema types."""

    def _execute_test_cases(self, test_cases):
        """Compare each case's generated fragment with its expected dict.

        ``test_cases`` maps a case name to a ``(schema type, expected
        fragment)`` pair. Every case runs in its own sub-test with a fresh
        scope and a fresh ``_JSONSchemaDefs``.
        """
        for name, case in test_cases.items():
            defs = schema._JSONSchemaDefs()
            scope = schema.ScopeType({}, "a")
            with self.subTest(name=name):
                schema_type, expected = case
                self.assertEqual(
                    expected,
                    schema_type._to_jsonschema_fragment(scope, defs),
                )

    def test_bool(self):
        """The bool fragment is an ``anyOf`` of boolean, string, integer."""
        defs = schema._JSONSchemaDefs()
        scope = schema.ScopeType({}, "a")
        fragment = schema.BoolType()._to_jsonschema_fragment(scope, defs)

        for index, json_type in enumerate(("boolean", "string", "integer")):
            self.assertEqual(fragment["anyOf"][index]["type"], json_type)

    def test_string(self):
        """String constraints map to minLength, maxLength and pattern."""
        test_cases: typing.Dict[
            str, typing.Tuple[schema.StringType, typing.Dict]
        ] = {
            "base": (schema.StringType(), {"type": "string"}),
            "min": (
                schema.StringType(min=5),
                {"type": "string", "minLength": 5},
            ),
            "max": (
                schema.StringType(max=5),
                {"type": "string", "maxLength": 5},
            ),
            "pattern": (
                schema.StringType(pattern=re.compile("^[a-z]+$")),
                {"type": "string", "pattern": "^[a-z]+$"},
            ),
        }
        self._execute_test_cases(test_cases)

    def test_int(self):
        """Integer bounds map to ``minimum`` and ``maximum``."""
        test_cases: typing.Dict[
            str, typing.Tuple[schema.IntType, typing.Dict]
        ] = {
            "base": (schema.IntType(), {"type": "integer"}),
            "min": (
                schema.IntType(min=5),
                {"type": "integer", "minimum": 5},
            ),
            "max": (
                schema.IntType(max=5),
                {"type": "integer", "maximum": 5},
            ),
        }
        self._execute_test_cases(test_cases)

    def test_float(self):
        """Float bounds map to ``minimum`` and ``maximum``."""
        test_cases: typing.Dict[
            str, typing.Tuple[schema.FloatType, typing.Dict]
        ] = {
            "base": (schema.FloatType(), {"type": "number"}),
            "min": (
                schema.FloatType(min=5.0),
                {"type": "number", "minimum": 5.0},
            ),
            "max": (
                schema.FloatType(max=5.0),
                {"type": "number", "maximum": 5.0},
            ),
        }
        self._execute_test_cases(test_cases)

    def test_enum(self):
        """Enum types list their member values under ``enum``."""

        class Color(enum.Enum):
            RED = "red"

        class Fibonacci(enum.Enum):
            FIRST = 1
            SECOND = 2

        test_cases: typing.Dict[
            str, typing.Tuple[schema._EnumType, typing.Dict]
        ] = {
            "string": (
                schema.StringEnumType(Color),
                {"type": "string", "enum": ["red"]},
            ),
            "int": (
                schema.IntEnumType(Fibonacci),
                {"type": "integer", "enum": [1, 2]},
            ),
        }
        self._execute_test_cases(test_cases)

    def test_list(self):
        """List bounds map to ``minItems`` and ``maxItems``."""
        test_cases: typing.Dict[
            str, typing.Tuple[schema.ListType, typing.Dict]
        ] = {
            "base": (
                schema.ListType(schema.IntType()),
                {"type": "array", "items": {"type": "integer"}},
            ),
            "min": (
                schema.ListType(schema.IntType(), min=3),
                {
                    "type": "array",
                    "items": {"type": "integer"},
                    "minItems": 3,
                },
            ),
            "max": (
                schema.ListType(schema.IntType(), max=3),
                {
                    "type": "array",
                    "items": {"type": "integer"},
                    "maxItems": 3,
                },
            ),
        }
        self._execute_test_cases(test_cases)

    def test_map(self):
        """Map bounds map to ``minProperties`` and ``maxProperties``."""
        test_cases: typing.Dict[
            str, typing.Tuple[schema.MapType, typing.Dict]
        ] = {
            "base": (
                schema.MapType(schema.IntType(), schema.StringType()),
                {
                    "type": "object",
                    "propertyNames": {"pattern": "^[0-9]+$"},
                    "additionalProperties": {"type": "string"},
                },
            ),
            "min": (
                schema.MapType(schema.StringType(), schema.IntType(), min=3),
                {
                    "type": "object",
                    "propertyNames": {},
                    "additionalProperties": {"type": "integer"},
                    "minProperties": 3,
                },
            ),
            "max": (
                schema.MapType(schema.StringType(), schema.IntType(), max=3),
                {
                    "type": "object",
                    "propertyNames": {},
                    "additionalProperties": {"type": "integer"},
                    "maxProperties": 3,
                },
            ),
        }
        self._execute_test_cases(test_cases)

    def test_object(self):
        """A scope's fragment contains ``$defs`` plus the inlined root."""

        @dataclasses.dataclass
        class TestData:
            a: str

        scope = schema.ScopeType({}, "TestData")
        scope.objects = {
            "TestData": schema.ObjectType(
                TestData,
                {
                    "a": schema.PropertyType(
                        schema.StringType(),
                        display=schema.DisplayValue("A", "A string"),
                    )
                },
            )
        }
        defs = schema._JSONSchemaDefs()

        expected = {
            "$defs": {
                "TestData": {
                    "type": "object",
                    "properties": {
                        "a": {
                            "type": "string",
                            "title": "A",
                            "description": "A string",
                        }
                    },
                    "required": ["a"],
                    "additionalProperties": False,
                    "dependentRequired": {},
                }
            },
            "type": "object",
            "properties": {
                "a": {
                    "type": "string",
                    "title": "A",
                    "description": "A string",
                }
            },
            "required": ["a"],
            "additionalProperties": False,
            "dependentRequired": {},
        }

        actual = scope._to_jsonschema_fragment(scope, defs)
        self.assertEqual(expected, actual)
[docs] def test_one_of(self): @dataclasses.dataclass class A: a: str @dataclasses.dataclass class B: b: str @dataclasses.dataclass class TestData: a: typing.Union[A, B] scope = schema.ScopeType( {}, "TestData", ) scope.objects = { "TestData": schema.ObjectType( TestData, { "a": schema.PropertyType( schema.OneOfStringType( { "a": schema.RefType("A", scope), "b": schema.RefType("B", scope), }, scope, "_type", ) ) }, ), "A": schema.ObjectType(A, {"a": schema.PropertyType(schema.StringType())}), "B": schema.ObjectType(B, {"b": schema.PropertyType(schema.StringType())}), } defs = schema._JSONSchemaDefs() json_schema = scope._to_jsonschema_fragment(scope, defs) self.assertEqual( { "$defs": { "TestData": { "type": "object", "properties": { "a": { "oneOf": [ {"$ref": "#/$defs/A_discriminated_string_a"}, {"$ref": "#/$defs/B_discriminated_string_b"}, ] } }, "required": ["a"], "additionalProperties": False, "dependentRequired": {}, }, "A": { "type": "object", "properties": { "a": {"type": "string"}, "_type": {"type": "string", "const": "a"}, }, "required": ["_type", "a"], "additionalProperties": False, "dependentRequired": {}, }, "A_discriminated_string_a": { "type": "object", "properties": { "a": {"type": "string"}, "_type": {"type": "string", "const": "a"}, }, "required": ["_type", "a"], "additionalProperties": False, "dependentRequired": {}, }, "B": { "type": "object", "properties": { "b": {"type": "string"}, "_type": {"type": "string", "const": "b"}, }, "required": ["_type", "b"], "additionalProperties": False, "dependentRequired": {}, }, "B_discriminated_string_b": { "type": "object", "properties": { "b": {"type": "string"}, "_type": {"type": "string", "const": "b"}, }, "required": ["_type", "b"], "additionalProperties": False, "dependentRequired": {}, }, }, "type": "object", "properties": { "a": { "oneOf": [ {"$ref": "#/$defs/A_discriminated_string_a"}, {"$ref": "#/$defs/B_discriminated_string_b"}, ] } }, "required": ["a"], "additionalProperties": False, 
"dependentRequired": {}, }, json_schema, )
def load_tests(loader, tests, ignore):
    """
    Add the doctests of the schema module to the discovery process.

    :param loader: not used by this function.
    :param tests: the suite that receives the doctests.
    :param ignore: not used by this function.
    :return: the ``tests`` suite that was passed in, now also containing the
        doctests.
    """
    schema_doctests = doctest.DocTestSuite(schema)
    tests.addTests(schema_doctests)
    return tests
# Run the tests when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()