Module: record_util

Expand source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import MISSING
from dataclasses import fields
from dataclasses import is_dataclass
from types import NoneType
from types import UnionType
from typing import Any
from typing import Iterable
from typing import List
from typing import Type
from typing import Union
from typing import get_args
from typing import get_origin
from cl.runtime.log.exceptions.user_error import UserError
from cl.runtime.records.protocols import RecordProtocol
from cl.runtime.records.protocols import is_record


class RecordUtil:
    """Utilities for working with records."""

    @classmethod
    def init_all(cls, obj) -> None:
        """
        Invoke 'init' for each class in the order from base to derived, then validate against schema.

        Args:
            obj: Object whose class hierarchy (MRO) is scanned for 'init' methods. Each distinct
                'init' implementation, identified by its qualified name, is called once with obj.
        """

        # Keep track of which init methods in class hierarchy were already called
        invoked = set()

        # Reverse the MRO to start from base to derived
        for class_ in reversed(obj.__class__.__mro__):
            class_init = getattr(class_, "init", None)
            # A class that does not override 'init' yields the inherited method, whose qualname is
            # already in 'invoked', so each implementation runs exactly once
            if class_init is not None and (qualname := class_init.__qualname__) not in invoked:
                invoked.add(qualname)
                class_init(obj)

        # Perform validation against the schema only after all init methods are called
        cls.validate(obj)

    @classmethod
    def validate(cls, obj) -> None:
        """
        Validate against schema (invoked by init_all after all init methods are called).

        Only dataclasses are checked; any other object is accepted without validation.

        Raises:
            RuntimeError: If a non-None field value does not match the declared field type.
            UserError: If a field is None, its type is not optional, and it has no usable default
                (the default is None or absent and there is no default_factory).
        """
        # TODO: Support other dataclass-like frameworks
        if not is_dataclass(obj):
            return

        class_name = obj.__class__.__name__
        for field in fields(obj):
            field_value = getattr(obj, field.name)
            if field_value is not None:
                # Check that for the fields that have values, the values are of the right type
                if not cls._is_instance(field_value, field.type):
                    value_type_name = type(field_value).__name__
                    if "member_descriptor" not in value_type_name:  # TODO(Roman): Remove when fixed
                        field_type_name = cls._get_field_type_name(field.type)
                        raise RuntimeError(
                            f"""Type mismatch for field '{field.name}' of class {class_name}.
Type in dataclass declaration: {field_type_name}
Type of the value: {value_type_name}
Note: In case of containers, type mismatch may be in one of the items.
"""
                        )
            else:
                # A field has no usable default when its default is None or absent (MISSING)
                # and it also has no default_factory
                default_is_none_or_missing = field.default is None or field.default is MISSING
                default_factory_is_missing = field.default_factory is MISSING
                default_value_not_set = default_is_none_or_missing and default_factory_is_missing
                if default_value_not_set and not cls._is_optional(field.type):
                    # Error if a field is None but declared as required
                    raise UserError(f"Field '{field.name}' in class '{class_name}' is required but not set.")

    @classmethod
    def is_abstract(cls, record_type: Type) -> bool:
        """Return True if 'record_type' is abstract."""
        return bool(inspect.isabstract(record_type))

    @classmethod
    def get_non_abstract_descendants(cls, record_type: Type) -> List[Type]:
        """
        Find non-abstract descendants of 'record_type' to all levels and return them as a list of types.

        Descendants of a subclass are listed before the subclass itself. Each type appears only once,
        even when it is reachable through several bases (diamond inheritance). 'record_type' itself
        is not included.
        """
        result = []
        for subclass in record_type.__subclasses__():
            # Recursively collect deeper descendants first
            result.extend(cls.get_non_abstract_descendants(subclass))
            # If the subclass is not abstract, add it to the list
            if not inspect.isabstract(subclass):
                result.append(subclass)
        # Drop duplicates produced by multiple inheritance while keeping first-occurrence order
        return list(dict.fromkeys(result))

    @classmethod
    def _is_instance(cls, field_value, field_type) -> bool:
        """
        Return True if 'field_value' conforms to the 'field_type' annotation.

        Supported annotations:
        - plain classes;
        - string annotations, compared with the value's class name;
        - typing.Any;
        - unions, including Optional;
        - generic aliases. For list the items are checked recursively, and for dict both keys and
          values are checked. Other generics (tuple, set, abstract collections, ...) are checked
          against their origin type only.

        Raises:
            RuntimeError: If 'field_type' is neither a type, a string, nor a supported generic.
        """
        # Any accepts every value (isinstance cannot be used with Any)
        if field_type is Any:
            return True

        origin = get_origin(field_type)
        args = get_args(field_type)

        if origin is None:
            # Not a generic type, consider the possible use of annotation
            if isinstance(field_type, type):
                return isinstance(field_value, field_type)
            elif isinstance(field_type, str):
                field_value_type_name = type(field_value).__name__
                return field_value_type_name == field_type
            else:
                raise RuntimeError(f"Field type {field_type} is neither a type nor a string.")
        elif origin in [UnionType, Union]:
            if field_value is None:
                return NoneType in args
            return any(cls._is_instance(field_value, arg) for arg in args)
        elif not cls._is_instance(field_value, origin):
            # Not an instance of the specified origin
            return False
        elif args and origin is list:
            return all(cls._is_instance(item, args[0]) for item in field_value)
        elif len(args) == 2 and origin is dict:
            return all(
                cls._is_instance(key, args[0]) and cls._is_instance(value, args[1])
                for key, value in field_value.items()
            )
        else:
            # Origin matches and item types are not checked for this generic (or it has no args)
            return True

    @classmethod
    def _is_optional(cls, field_type) -> bool:
        """Return True if None is a valid value for field_type (a Union that includes NoneType)."""
        # Check if the type is a union
        if get_origin(field_type) in [UnionType, Union]:
            # Check if NoneType is one of the arguments in the union
            return NoneType in get_args(field_type)
        else:
            # Type hint is not a union, the value cannot be None
            return False

    @classmethod
    def _get_field_type_name(cls, field_type) -> str:
        """
        Get a readable name of a type annotation, including handling for Union types.

        String annotations are returned as is. Annotations without '__name__' (e.g. ForwardRef)
        fall back to their str() form so that error reporting never fails.
        """
        if isinstance(field_type, str):
            return field_type
        if get_origin(field_type) in [UnionType, Union]:
            return " | ".join(cls._get_field_type_name(t) for t in get_args(field_type))
        return getattr(field_type, "__name__", None) or str(field_type)

    @classmethod
    def sort_records_by_key(cls, records: Iterable[RecordProtocol]) -> List[RecordProtocol]:
        """
        Sort records by string key fields.

        The sort key joins with ';' the string-valued record attributes named in the key's __slots__.
        Non-string values are skipped, and objects for which is_record returns False get an empty
        sort key. The sort is stable, so records with equal sort keys keep their input order.
        """

        # TODO (Roman): Check string key fields in nested keys
        sort_records: Any = []
        for record in records:
            # TODO: Refactor to use a key serializer
            key_slots = record.get_key().__slots__ if is_record(record) else tuple()  # noqa
            str_key_values = [v for slot in key_slots if isinstance((v := getattr(record, slot)), str)]
            sort_key = ";".join(str_key_values)
            sort_records.append((sort_key, record))

        return [record for _, record in sorted(sort_records, key=lambda x: x[0])]

Classes

class RecordUtil

Utilities for working with records.

Expand source code
class RecordUtil:
    """Utilities for working with records."""

    @classmethod
    def init_all(cls, obj) -> None:
        """
        Invoke 'init' for each class in the order from base to derived, then validate against schema.

        Args:
            obj: Object whose class hierarchy (MRO) is scanned for 'init' methods. Each distinct
                'init' implementation, identified by its qualified name, is called once with obj.
        """

        # Keep track of which init methods in class hierarchy were already called
        invoked = set()

        # Reverse the MRO to start from base to derived
        for class_ in reversed(obj.__class__.__mro__):
            class_init = getattr(class_, "init", None)
            # A class that does not override 'init' yields the inherited method, whose qualname is
            # already in 'invoked', so each implementation runs exactly once
            if class_init is not None and (qualname := class_init.__qualname__) not in invoked:
                invoked.add(qualname)
                class_init(obj)

        # Perform validation against the schema only after all init methods are called
        cls.validate(obj)

    @classmethod
    def validate(cls, obj) -> None:
        """
        Validate against schema (invoked by init_all after all init methods are called).

        Only dataclasses are checked; any other object is accepted without validation.

        Raises:
            RuntimeError: If a non-None field value does not match the declared field type.
            UserError: If a field is None, its type is not optional, and it has no usable default
                (the default is None or absent and there is no default_factory).
        """
        # TODO: Support other dataclass-like frameworks
        if not is_dataclass(obj):
            return

        class_name = obj.__class__.__name__
        for field in fields(obj):
            field_value = getattr(obj, field.name)
            if field_value is not None:
                # Check that for the fields that have values, the values are of the right type
                if not cls._is_instance(field_value, field.type):
                    value_type_name = type(field_value).__name__
                    if "member_descriptor" not in value_type_name:  # TODO(Roman): Remove when fixed
                        field_type_name = cls._get_field_type_name(field.type)
                        raise RuntimeError(
                            f"""Type mismatch for field '{field.name}' of class {class_name}.
Type in dataclass declaration: {field_type_name}
Type of the value: {value_type_name}
Note: In case of containers, type mismatch may be in one of the items.
"""
                        )
            else:
                # A field has no usable default when its default is None or absent (MISSING)
                # and it also has no default_factory
                default_is_none_or_missing = field.default is None or field.default is MISSING
                default_factory_is_missing = field.default_factory is MISSING
                default_value_not_set = default_is_none_or_missing and default_factory_is_missing
                if default_value_not_set and not cls._is_optional(field.type):
                    # Error if a field is None but declared as required
                    raise UserError(f"Field '{field.name}' in class '{class_name}' is required but not set.")

    @classmethod
    def is_abstract(cls, record_type: Type) -> bool:
        """Return True if 'record_type' is abstract."""
        return bool(inspect.isabstract(record_type))

    @classmethod
    def get_non_abstract_descendants(cls, record_type: Type) -> List[Type]:
        """
        Find non-abstract descendants of 'record_type' to all levels and return them as a list of types.

        Descendants of a subclass are listed before the subclass itself. Each type appears only once,
        even when it is reachable through several bases (diamond inheritance). 'record_type' itself
        is not included.
        """
        result = []
        for subclass in record_type.__subclasses__():
            # Recursively collect deeper descendants first
            result.extend(cls.get_non_abstract_descendants(subclass))
            # If the subclass is not abstract, add it to the list
            if not inspect.isabstract(subclass):
                result.append(subclass)
        # Drop duplicates produced by multiple inheritance while keeping first-occurrence order
        return list(dict.fromkeys(result))

    @classmethod
    def _is_instance(cls, field_value, field_type) -> bool:
        """
        Return True if 'field_value' conforms to the 'field_type' annotation.

        Supported annotations:
        - plain classes;
        - string annotations, compared with the value's class name;
        - typing.Any;
        - unions, including Optional;
        - generic aliases. For list the items are checked recursively, and for dict both keys and
          values are checked. Other generics (tuple, set, abstract collections, ...) are checked
          against their origin type only.

        Raises:
            RuntimeError: If 'field_type' is neither a type, a string, nor a supported generic.
        """
        # Any accepts every value (isinstance cannot be used with Any)
        if field_type is Any:
            return True

        origin = get_origin(field_type)
        args = get_args(field_type)

        if origin is None:
            # Not a generic type, consider the possible use of annotation
            if isinstance(field_type, type):
                return isinstance(field_value, field_type)
            elif isinstance(field_type, str):
                field_value_type_name = type(field_value).__name__
                return field_value_type_name == field_type
            else:
                raise RuntimeError(f"Field type {field_type} is neither a type nor a string.")
        elif origin in [UnionType, Union]:
            if field_value is None:
                return NoneType in args
            return any(cls._is_instance(field_value, arg) for arg in args)
        elif not cls._is_instance(field_value, origin):
            # Not an instance of the specified origin
            return False
        elif args and origin is list:
            return all(cls._is_instance(item, args[0]) for item in field_value)
        elif len(args) == 2 and origin is dict:
            return all(
                cls._is_instance(key, args[0]) and cls._is_instance(value, args[1])
                for key, value in field_value.items()
            )
        else:
            # Origin matches and item types are not checked for this generic (or it has no args)
            return True

    @classmethod
    def _is_optional(cls, field_type) -> bool:
        """Return True if None is a valid value for field_type (a Union that includes NoneType)."""
        # Check if the type is a union
        if get_origin(field_type) in [UnionType, Union]:
            # Check if NoneType is one of the arguments in the union
            return NoneType in get_args(field_type)
        else:
            # Type hint is not a union, the value cannot be None
            return False

    @classmethod
    def _get_field_type_name(cls, field_type) -> str:
        """
        Get a readable name of a type annotation, including handling for Union types.

        String annotations are returned as is. Annotations without '__name__' (e.g. ForwardRef)
        fall back to their str() form so that error reporting never fails.
        """
        if isinstance(field_type, str):
            return field_type
        if get_origin(field_type) in [UnionType, Union]:
            return " | ".join(cls._get_field_type_name(t) for t in get_args(field_type))
        return getattr(field_type, "__name__", None) or str(field_type)

    @classmethod
    def sort_records_by_key(cls, records: Iterable[RecordProtocol]) -> List[RecordProtocol]:
        """
        Sort records by string key fields.

        The sort key joins with ';' the string-valued record attributes named in the key's __slots__.
        Non-string values are skipped, and objects for which is_record returns False get an empty
        sort key. The sort is stable, so records with equal sort keys keep their input order.
        """

        # TODO (Roman): Check string key fields in nested keys
        sort_records: Any = []
        for record in records:
            # TODO: Refactor to use a key serializer
            key_slots = record.get_key().__slots__ if is_record(record) else tuple()  # noqa
            str_key_values = [v for slot in key_slots if isinstance((v := getattr(record, slot)), str)]
            sort_key = ";".join(str_key_values)
            sort_records.append((sort_key, record))

        return [record for _, record in sorted(sort_records, key=lambda x: x[0])]

Class methods

def get_non_abstract_descendants(record_type: Type) -> List[Type]

Find all non-abstract descendants of ‘record_type’ at every level of inheritance and return them as a list of types.

def init_all(obj) -> None

Invoke ‘init’ for each class in the order from base to derived, then validate against schema.

def is_abstract(record_type: Type) -> bool

Return True if ‘record_type’ is abstract.

def sort_records_by_key(records: Iterable[RecordProtocol]) -> List[RecordProtocol]

Sort records by string key fields.

def validate(obj) -> None

Validate against schema (invoked by init_all after all init methods are called).