Module: dict_serializer
Expand source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections import Counter
from dataclasses import dataclass
from enum import Enum
from typing import Dict
from typing import List
from typing import Tuple
from typing import Type
from typing import cast
from cl.runtime.backend.core.base_type_info import BaseTypeInfo
from cl.runtime.backend.core.tab_info import TabInfo
from cl.runtime.log.exceptions.user_error import UserError
from cl.runtime.primitive.case_util import CaseUtil
from cl.runtime.records.protocols import TDataDict
from cl.runtime.records.protocols import TPrimitive
from cl.runtime.records.protocols import is_key
from cl.runtime.records.protocols import is_record
from cl.runtime.records.record_util import RecordUtil
from cl.runtime.serialization.sentinel_type import sentinel_value
# TODO: Initialize from settings
alias_dict: Dict[Type, str] = dict()
"""Dictionary of class name aliases using type as key (includes classes and enums with aliases only)."""
# TODO: Initialize from settings
_type_dict: Dict[str, Type] = None
"""Dictionary of types using class name or alias as key (includes all classes and enums)."""
class_hierarchy_slots_dict: Dict[Type, Tuple] = dict()
"""Dictionary of slots in class hierarchy in the order of declaration from base to derived."""
collect_slots = sys.version_info.major > 3 or sys.version_info.major == 3 and sys.version_info.minor >= 11
"""For Python 3.11 and later, __slots__ includes fields for this class only, use MRO to include base class slots."""
# TODO: Should classes not included in packages be supported? If not, do not update type dict in serializer.
def get_type_dict() -> Dict[str, Type]:
    """
    Return the type dictionary keyed by class name or alias, loading it from the schema on first call.

    On the first call the dictionary is obtained from Schema.get_type_dict(), extended with TabInfo and
    BaseTypeInfo (needed for UiAppState deserialization), and stored in the module-level _type_dict.
    Later calls return that same object.
    """
    global _type_dict
    if _type_dict is not None:
        # Already loaded, return the cached dictionary
        return _type_dict

    from cl.runtime.schema.schema import Schema  # TODO: Refactor to avoid cyclic dependency

    _type_dict = Schema.get_type_dict()
    # TODO (Roman): include all needed types to type_dict automatically
    # Add data types needed for UiAppState deserialization to type_dict manually
    _type_dict.update({extra_type.__name__: extra_type for extra_type in (TabInfo, BaseTypeInfo)})
    return _type_dict
def _get_class_hierarchy_slots(data_type) -> Tuple[str, ...]:
    """
    Tuple of slots in class hierarchy in the order of declaration from base to derived.

    Results are cached in class_hierarchy_slots_dict. The special '__dict__' and '__weakref__' slots
    are excluded because they do not hold field values.

    Args:
        data_type: Class whose slots are collected

    Returns:
        Tuple of slot names in the order of declaration from base to derived

    Raises:
        RuntimeError: If the same slot name is declared more than once in the class hierarchy
    """
    # Use cached value if present
    if (result := class_hierarchy_slots_dict.get(data_type, None)) is not None:
        return result

    if collect_slots:
        # For v3.11 and later, __slots__ includes fields for this class only, use MRO to collect base class slots.
        # Read __slots__ from each class's own namespace: getattr would also return the __slots__ inherited
        # from a base when the class itself declares none, counting the same slots twice.
        # Exclude None or empty __slots__ (both are falsy)
        slots_list = [slots for base in reversed(data_type.__mro__) if (slots := vars(base).get("__slots__", None))]
    else:
        # Otherwise get slots from this type only
        # Exclude None or empty __slots__ (both are falsy)
        slots_list = [slots if (slots := getattr(data_type, "__slots__", None)) else tuple()]

    # Convert slots specified as a single string into tuple of size one
    slots_list = [(slots,) if isinstance(slots, str) else slots for slots in slots_list]

    # Flatten into a tuple, skipping special slots that are not data fields
    result = tuple(slot for sublist in slots_list for slot in sublist if slot not in ("__dict__", "__weakref__"))

    # Check for duplicates
    if len(result) > len(set(result)):
        # Error message if duplicates are found
        counts = Counter(result)
        duplicates_str = ", ".join(slot for slot, count in counts.items() if count > 1)
        raise RuntimeError(
            f"Duplicate field names found in class hierarchy " f"for {data_type.__name__}: {duplicates_str}."
        )

    class_hierarchy_slots_dict[data_type] = result
    return cast(Tuple[str, ...], result)
# TODO: Add checks for to_node, from_node implementation for custom override of default serializer
@dataclass(slots=True, kw_only=True)
class DictSerializer:
"""Serialization for slots-based classes (including dataclasses with slots=True)."""
pascalize_keys: bool = False
"""If true, pascalize keys during serialization."""
primitive_type_names = ["NoneType", "str", "float", "int", "bool", "date", "time", "datetime", "bytes", "UUID"]
"""Detect primitive type by checking if class name is in this list."""
def serialize_data(self, data, select_fields: List[str] | None = None): # TODO: Check if None should be supported
"""
Serialize to dictionary containing primitive types, dictionaries, or iterables.
Notes:
Before serialization, invoke 'init' for each class in class hierarchy that implements it,
in the order from base to derived.
Args:
data: Object to serialize
select_fields: Fields of data object which will be used for serialization. If None - use all fields.
"""
if getattr(data, "__slots__", None) is not None:
# Slots class, serialize as dictionary
# Invoke 'init' for each class in class hierarchy that implements it, in the order from base to derived
RecordUtil.init_all(data)
# Get slots from this class and its bases in the order of declaration from base to derived
all_slots = _get_class_hierarchy_slots(data.__class__)
# Serialize slot values in the order of declaration except those that are None
result = {
k if not self.pascalize_keys else CaseUtil.snake_to_pascal_case_keep_trailing_underscore(k): (
v if v.__class__.__name__ in self.primitive_type_names else self.serialize_data(v)
)
for k in all_slots
if (not select_fields or k in select_fields) and (v := getattr(data, k)) is not None
}
# To find short name, use 'in' which is faster than 'get' when most types do not have aliases
short_name = alias_dict[type_] if (type_ := data.__class__) in alias_dict else type_.__name__
# Cache type for subsequent reverse lookup
type_dict = get_type_dict()
type_dict[short_name] = type_
# Add to result
result["_type"] = short_name
return result
elif isinstance(data, dict):
# Dictionary, return with serialized values
result = {
k: v if v.__class__.__name__ in self.primitive_type_names else self.serialize_data(v)
for k, v in data.items()
}
return result
elif hasattr(data, "__iter__"):
# Get the first item without iterating over the entire sequence
first_item = next(iter(data), sentinel_value)
if first_item == sentinel_value:
# Empty iterable, return None
return None
elif first_item is not None and first_item.__class__.__name__ in self.primitive_type_names:
# Performance optimization to skip deserialization for arrays of primitive types
# based on the type of first item (assumes that all remaining items are also primitive)
return data
else:
# Serialize each element of the iterable
return [
v if v.__class__.__name__ in self.primitive_type_names else self.serialize_data(v) for v in data
]
elif isinstance(data, Enum):
# Serialize enum as a dict using enum class short name and item name (rather than item value)
# To find short name, use 'in' which is faster than 'get' when most types do not have aliases
short_name = alias_dict[type_] if (type_ := type(data)) in alias_dict else type_.__name__
# Cache type for subsequent reverse lookup
type_dict = get_type_dict()
type_dict[short_name] = type_
pascal_case_value = CaseUtil.upper_to_pascal_case(data.name)
return {"_enum": short_name, "_name": pascal_case_value}
else:
raise RuntimeError(f"Cannot serialize data of type '{type(data)}'.")
def deserialize_data(self, data: TDataDict): # TODO: Check if None should be supported
"""Deserialize object from data, invoke init_all after deserialization."""
if isinstance(data, dict):
# Determine if the dictionary is a serialized dataclass or a dictionary
if (short_name := data.get("_type", None)) is not None:
# If _type is specified, create an instance of _type after deserializing fields recursively
type_dict = get_type_dict()
deserialized_type = type_dict.get(short_name, None) # noqa
if deserialized_type is None:
raise RuntimeError(
f"Class not found for name or alias '{short_name}' during deserialization. "
f"Ensure all serialized classes are included in package import settings."
)
# Check if the class is abstract
if RecordUtil.is_abstract(deserialized_type):
descendants = RecordUtil.get_non_abstract_descendants(deserialized_type)
descendant_names = sorted(set([x.__name__ for x in descendants]))
if len(descendant_names) > 0:
descendant_names_str = ", ".join(descendant_names)
raise UserError(
f"Record {deserialized_type.__name__} cannot be created directly, "
f"but the following descendant records can: {descendant_names_str}"
)
else:
raise UserError(
f"Record {deserialized_type.__name__} cannot be created directly "
f"and there are no descendant records that can."
)
deserialized_fields = {
CaseUtil.pascale_to_snake_case_keep_trailing_underscore(k) if self.pascalize_keys else k: (
v if v.__class__.__name__ in self.primitive_type_names else self.deserialize_data(v)
)
for k, v in data.items()
if k != "_type"
}
result = deserialized_type(**deserialized_fields) # noqa
# Invoke 'init' for each class in class hierarchy that implements it, in the order from base to derived
RecordUtil.init_all(result)
return result
elif (short_name := data.get("_enum", None)) is not None:
# If _enum is specified, create an instance of _enum using _name
type_dict = get_type_dict()
deserialized_type = type_dict.get(short_name, None) # noqa
if deserialized_type is None:
raise RuntimeError(
f"Enum not found for name or alias '{short_name}' during deserialization. "
f"Ensure all serialized enums are included in package import settings."
)
pascal_case_value = data["_name"]
upper_case_value = CaseUtil.pascal_to_upper_case(pascal_case_value)
result = deserialized_type[upper_case_value] # noqa
return result
else:
# Otherwise return a dictionary with recursively deserialized values
result = {
k: v if v.__class__.__name__ in self.primitive_type_names else self.deserialize_data(v)
for k, v in data.items()
}
return result
elif hasattr(data, "__iter__"):
# Get the first item without iterating over the entire sequence
first_item = next(iter(data), sentinel_value)
if first_item == sentinel_value:
# Empty iterable, return None
return None
elif first_item is not None and first_item.__class__.__name__ in self.primitive_type_names:
# Performance optimization to skip deserialization for arrays of primitive types
# based on the type of first item (assumes that all remaining items are also primitive)
return data
else:
# Deserialize each element of the iterable
return [
v if v.__class__.__name__ in self.primitive_type_names else self.deserialize_data(v) for v in data
]
elif is_key(data) or is_record(data):
return data
else:
raise RuntimeError(f"Cannot deserialize data of type '{type(data)}'.")
@classmethod
def _serialize_primitive(cls, value: TPrimitive, class_name: str) -> TPrimitive:
"""Serialize primitive value applying the applicable conversion rules."""
# TODO: Use switch statement
if class_name == "bool":
return "Y" if value else "N"
else:
return value
@classmethod
def _deserialize_primitive(cls, value: TPrimitive, class_name: str) -> TPrimitive:
"""Deserialize primitive value applying the applicable conversion rules."""
# TODO: Use switch statement
if class_name == "bool":
# TODO: Use switch statement
if isinstance(value, bool):
return value
elif value == "Y":
return True
elif value == "N":
return False
else:
raise RuntimeError(f"Serialized boolean field has value {value} but only Y or N are allowed.")
else:
return value
Global variables
var alias_dict
-
Dictionary of class name aliases using type as key (includes classes and enums with aliases only).
var class_hierarchy_slots_dict
-
Dictionary of slots in class hierarchy in the order of declaration from base to derived.
var collect_slots
-
For Python 3.11 and later, `__slots__` includes fields for this class only; use MRO to include base class slots.
Functions
def get_type_dict() -> Dict[str, Type]
-
Load type dictionary from schema if not present.
Classes
class DictSerializer (*, pascalize_keys: bool = False)
-
Serialization for slots-based classes (including dataclasses with slots=True).
Expand source code
@dataclass(slots=True, kw_only=True) class DictSerializer: """Serialization for slots-based classes (including dataclasses with slots=True).""" pascalize_keys: bool = False """If true, pascalize keys during serialization.""" primitive_type_names = ["NoneType", "str", "float", "int", "bool", "date", "time", "datetime", "bytes", "UUID"] """Detect primitive type by checking if class name is in this list.""" def serialize_data(self, data, select_fields: List[str] | None = None): # TODO: Check if None should be supported """ Serialize to dictionary containing primitive types, dictionaries, or iterables. Notes: Before serialization, invoke 'init' for each class in class hierarchy that implements it, in the order from base to derived. Args: data: Object to serialize select_fields: Fields of data object which will be used for serialization. If None - use all fields. """ if getattr(data, "__slots__", None) is not None: # Slots class, serialize as dictionary # Invoke 'init' for each class in class hierarchy that implements it, in the order from base to derived RecordUtil.init_all(data) # Get slots from this class and its bases in the order of declaration from base to derived all_slots = _get_class_hierarchy_slots(data.__class__) # Serialize slot values in the order of declaration except those that are None result = { k if not self.pascalize_keys else CaseUtil.snake_to_pascal_case_keep_trailing_underscore(k): ( v if v.__class__.__name__ in self.primitive_type_names else self.serialize_data(v) ) for k in all_slots if (not select_fields or k in select_fields) and (v := getattr(data, k)) is not None } # To find short name, use 'in' which is faster than 'get' when most types do not have aliases short_name = alias_dict[type_] if (type_ := data.__class__) in alias_dict else type_.__name__ # Cache type for subsequent reverse lookup type_dict = get_type_dict() type_dict[short_name] = type_ # Add to result result["_type"] = short_name return result elif isinstance(data, dict): # 
Dictionary, return with serialized values result = { k: v if v.__class__.__name__ in self.primitive_type_names else self.serialize_data(v) for k, v in data.items() } return result elif hasattr(data, "__iter__"): # Get the first item without iterating over the entire sequence first_item = next(iter(data), sentinel_value) if first_item == sentinel_value: # Empty iterable, return None return None elif first_item is not None and first_item.__class__.__name__ in self.primitive_type_names: # Performance optimization to skip deserialization for arrays of primitive types # based on the type of first item (assumes that all remaining items are also primitive) return data else: # Serialize each element of the iterable return [ v if v.__class__.__name__ in self.primitive_type_names else self.serialize_data(v) for v in data ] elif isinstance(data, Enum): # Serialize enum as a dict using enum class short name and item name (rather than item value) # To find short name, use 'in' which is faster than 'get' when most types do not have aliases short_name = alias_dict[type_] if (type_ := type(data)) in alias_dict else type_.__name__ # Cache type for subsequent reverse lookup type_dict = get_type_dict() type_dict[short_name] = type_ pascal_case_value = CaseUtil.upper_to_pascal_case(data.name) return {"_enum": short_name, "_name": pascal_case_value} else: raise RuntimeError(f"Cannot serialize data of type '{type(data)}'.") def deserialize_data(self, data: TDataDict): # TODO: Check if None should be supported """Deserialize object from data, invoke init_all after deserialization.""" if isinstance(data, dict): # Determine if the dictionary is a serialized dataclass or a dictionary if (short_name := data.get("_type", None)) is not None: # If _type is specified, create an instance of _type after deserializing fields recursively type_dict = get_type_dict() deserialized_type = type_dict.get(short_name, None) # noqa if deserialized_type is None: raise RuntimeError( f"Class not found for name 
or alias '{short_name}' during deserialization. " f"Ensure all serialized classes are included in package import settings." ) # Check if the class is abstract if RecordUtil.is_abstract(deserialized_type): descendants = RecordUtil.get_non_abstract_descendants(deserialized_type) descendant_names = sorted(set([x.__name__ for x in descendants])) if len(descendant_names) > 0: descendant_names_str = ", ".join(descendant_names) raise UserError( f"Record {deserialized_type.__name__} cannot be created directly, " f"but the following descendant records can: {descendant_names_str}" ) else: raise UserError( f"Record {deserialized_type.__name__} cannot be created directly " f"and there are no descendant records that can." ) deserialized_fields = { CaseUtil.pascale_to_snake_case_keep_trailing_underscore(k) if self.pascalize_keys else k: ( v if v.__class__.__name__ in self.primitive_type_names else self.deserialize_data(v) ) for k, v in data.items() if k != "_type" } result = deserialized_type(**deserialized_fields) # noqa # Invoke 'init' for each class in class hierarchy that implements it, in the order from base to derived RecordUtil.init_all(result) return result elif (short_name := data.get("_enum", None)) is not None: # If _enum is specified, create an instance of _enum using _name type_dict = get_type_dict() deserialized_type = type_dict.get(short_name, None) # noqa if deserialized_type is None: raise RuntimeError( f"Enum not found for name or alias '{short_name}' during deserialization. " f"Ensure all serialized enums are included in package import settings." 
) pascal_case_value = data["_name"] upper_case_value = CaseUtil.pascal_to_upper_case(pascal_case_value) result = deserialized_type[upper_case_value] # noqa return result else: # Otherwise return a dictionary with recursively deserialized values result = { k: v if v.__class__.__name__ in self.primitive_type_names else self.deserialize_data(v) for k, v in data.items() } return result elif hasattr(data, "__iter__"): # Get the first item without iterating over the entire sequence first_item = next(iter(data), sentinel_value) if first_item == sentinel_value: # Empty iterable, return None return None elif first_item is not None and first_item.__class__.__name__ in self.primitive_type_names: # Performance optimization to skip deserialization for arrays of primitive types # based on the type of first item (assumes that all remaining items are also primitive) return data else: # Deserialize each element of the iterable return [ v if v.__class__.__name__ in self.primitive_type_names else self.deserialize_data(v) for v in data ] elif is_key(data) or is_record(data): return data else: raise RuntimeError(f"Cannot deserialize data of type '{type(data)}'.") @classmethod def _serialize_primitive(cls, value: TPrimitive, class_name: str) -> TPrimitive: """Serialize primitive value applying the applicable conversion rules.""" # TODO: Use switch statement if class_name == "bool": return "Y" if value else "N" else: return value @classmethod def _deserialize_primitive(cls, value: TPrimitive, class_name: str) -> TPrimitive: """Deserialize primitive value applying the applicable conversion rules.""" # TODO: Use switch statement if class_name == "bool": # TODO: Use switch statement if isinstance(value, bool): return value elif value == "Y": return True elif value == "N": return False else: raise RuntimeError(f"Serialized boolean field has value {value} but only Y or N are allowed.") else: return value
Subclasses
Class variables
var primitive_type_names
-
Detect primitive type by checking if class name is in this list.
Fields
var pascalize_keys -> bool
-
If true, pascalize keys during serialization and convert them back to snake case during deserialization.
Methods
def deserialize_data(self, data: Dict[str, Union[Dict[str, ForwardRef('TDataField')], List[ForwardRef('TDataField')], str, float, bool, int, datetime.date, datetime.time, datetime.datetime, uuid.UUID, bytes, ForwardRef(None), enum.Enum]])
-
Deserialize object from data, invoke init_all after deserialization.
def serialize_data(self, data, select_fields: Optional[List[str]] = None)
-
Serialize to dictionary containing primitive types, dictionaries, or iterables.
Notes
Before serialization, invoke ‘init’ for each class in class hierarchy that implements it, in the order from base to derived.
Args
data
- Object to serialize
select_fields
- Fields of the data object to include in serialization. If None or empty – use all fields.