Module: testing_context

Expand source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from cl.runtime.backend.core.user_key import UserKey
from cl.runtime.context.context import Context
from cl.runtime.context.env_util import EnvUtil
from cl.runtime.db.dataset_util import DatasetUtil
from cl.runtime.records.class_info import ClassInfo
from cl.runtime.settings.context_settings import ContextSettings
from cl.runtime.settings.settings import Settings


@dataclass(slots=True, kw_only=True)
class TestingContext(Context):
    """
    Utilities for both pytest and unittest.

    Notes:
        - The name TestingContext was selected to avoid the Test prefix and does not indicate it is for a specific package
        - This module does not itself import the pytest or unittest package
    """

    db_class: str | None = None
    """Override for the database class in module.ClassName format."""

    def __post_init__(self) -> None:
        """
        Configure fields that were not specified in constructor.

        Skipped for deserialized instances (e.g. contexts passed to a task queue).

        Raises:
            RuntimeError: If created while Settings.is_inside_test is false.
        """

        # Do not execute this code on deserialized context instances (e.g. when they are passed to a task queue)
        if self.is_deserialized:
            return

        # Confirm we are inside a test, error otherwise
        if not Settings.is_inside_test:
            raise RuntimeError("TestingContext created outside a test.")

        # Settings that supply the default log and database classes
        context_settings = ContextSettings.instance()

        # For the test, env name is dot-delimited test module, class in snake_case (if any), and method or function
        env_name = EnvUtil.get_env_name()

        # Use test name in dot-delimited format for context_id unless specified by the caller
        if self.context_id is None:
            self.context_id = env_name

        # Set user to env name for unit testing
        self.user = UserKey(username=env_name)

        # TODO: Set log field here explicitly instead of relying on implicit detection of test environment
        log_type = ClassInfo.get_class_type(context_settings.log_class)
        self.log = log_type(log_id=self.context_id)

        # Use database class from settings unless this instance provides an override
        db_class = self.db_class if self.db_class is not None else context_settings.db_class

        # Use 'temp' followed by context_id converted to semicolon-delimited format for db_id
        db_id = "temp;" + self.context_id.replace(".", ";")

        # Instantiate a new database object for every test
        db_type = ClassInfo.get_class_type(db_class)
        self.db = db_type(db_id=db_id)

        # Root dataset
        self.dataset = DatasetUtil.root()

    def __enter__(self) -> "TestingContext":
        """
        Supports 'with' operator for resource disposal.

        Enters the base Context first, then deletes all data and drops the temp DB in case it
        was left over from an abnormally terminated previous run (skipped for deserialized instances).
        If that cleanup raises, the base Context is exited before the error propagates.
        """

        # Call '__enter__' method of base first
        Context.__enter__(self)

        # Do not execute this code on deserialized context instances (e.g. when they are passed to a task queue)
        if not self.is_deserialized:
            try:
                # Delete all existing data in temp database and drop DB in case it was not cleaned up
                # due to abnormal termination of the previous test run
                self.db.delete_all_and_drop_db()  # noqa
            except BaseException as e:
                # The 'with' statement does not call '__exit__' when '__enter__' raises,
                # so undo the base '__enter__' here before propagating the error
                Context.__exit__(self, type(e), e, e.__traceback__)
                raise

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Supports 'with' operator for resource disposal.

        Deletes all data in the temp DB and drops it (skipped for deserialized instances), then
        exits the base Context. The base '__exit__' runs even if the cleanup raises, and its
        return value is passed through to the 'with' statement.
        """

        try:
            # Do not execute this code on deserialized context instances (e.g. when they are passed to a task queue)
            if not self.is_deserialized:
                # Delete all data in temp database and drop DB to clean up
                self.db.delete_all_and_drop_db()  # noqa
        finally:
            # Call '__exit__' method of base last, even if the cleanup above failed
            result = Context.__exit__(self, exc_type, exc_val, exc_tb)
        return result

Classes

class TestingContext (*, context_id: str = None, user: UserKey = None, log: LogKey = None, db: DbKey = None, dataset: str = None, is_deserialized: bool = False, db_class: str | None = None)

Utilities for both pytest and unittest.

Notes

  • The name TestingContext was selected to avoid the Test prefix and does not indicate it is for a specific package
  • This module does not itself import the pytest or unittest package
Expand source code
@dataclass(slots=True, kw_only=True)
class TestingContext(Context):
    """
    Utilities for both pytest and unittest.

    Notes:
        - The name TestingContext was selected to avoid the Test prefix and does not indicate it is for a specific package
        - This module does not itself import the pytest or unittest package
    """

    db_class: str | None = None
    """Override for the database class in module.ClassName format."""

    def __post_init__(self) -> None:
        """
        Configure fields that were not specified in constructor.

        Skipped for deserialized instances (e.g. contexts passed to a task queue).

        Raises:
            RuntimeError: If created while Settings.is_inside_test is false.
        """

        # Do not execute this code on deserialized context instances (e.g. when they are passed to a task queue)
        if self.is_deserialized:
            return

        # Confirm we are inside a test, error otherwise
        if not Settings.is_inside_test:
            raise RuntimeError("TestingContext created outside a test.")

        # Settings that supply the default log and database classes
        context_settings = ContextSettings.instance()

        # For the test, env name is dot-delimited test module, class in snake_case (if any), and method or function
        env_name = EnvUtil.get_env_name()

        # Use test name in dot-delimited format for context_id unless specified by the caller
        if self.context_id is None:
            self.context_id = env_name

        # Set user to env name for unit testing
        self.user = UserKey(username=env_name)

        # TODO: Set log field here explicitly instead of relying on implicit detection of test environment
        log_type = ClassInfo.get_class_type(context_settings.log_class)
        self.log = log_type(log_id=self.context_id)

        # Use database class from settings unless this instance provides an override
        db_class = self.db_class if self.db_class is not None else context_settings.db_class

        # Use 'temp' followed by context_id converted to semicolon-delimited format for db_id
        db_id = "temp;" + self.context_id.replace(".", ";")

        # Instantiate a new database object for every test
        db_type = ClassInfo.get_class_type(db_class)
        self.db = db_type(db_id=db_id)

        # Root dataset
        self.dataset = DatasetUtil.root()

    def __enter__(self) -> "TestingContext":
        """
        Supports 'with' operator for resource disposal.

        Enters the base Context first, then deletes all data and drops the temp DB in case it
        was left over from an abnormally terminated previous run (skipped for deserialized instances).
        If that cleanup raises, the base Context is exited before the error propagates.
        """

        # Call '__enter__' method of base first
        Context.__enter__(self)

        # Do not execute this code on deserialized context instances (e.g. when they are passed to a task queue)
        if not self.is_deserialized:
            try:
                # Delete all existing data in temp database and drop DB in case it was not cleaned up
                # due to abnormal termination of the previous test run
                self.db.delete_all_and_drop_db()  # noqa
            except BaseException as e:
                # The 'with' statement does not call '__exit__' when '__enter__' raises,
                # so undo the base '__enter__' here before propagating the error
                Context.__exit__(self, type(e), e, e.__traceback__)
                raise

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Supports 'with' operator for resource disposal.

        Deletes all data in the temp DB and drops it (skipped for deserialized instances), then
        exits the base Context. The base '__exit__' runs even if the cleanup raises, and its
        return value is passed through to the 'with' statement.
        """

        try:
            # Do not execute this code on deserialized context instances (e.g. when they are passed to a task queue)
            if not self.is_deserialized:
                # Delete all data in temp database and drop DB to clean up
                self.db.delete_all_and_drop_db()  # noqa
        finally:
            # Call '__exit__' method of base last, even if the cleanup above failed
            result = Context.__exit__(self, exc_type, exc_val, exc_tb)
        return result

Ancestors

  • Context

Static methods

def current()

Inherited from: Context.current

Return the current context or None if not set.

def error_if_not_temp_db(db_id_or_database_name: str) -> None

Inherited from: Context.error_if_not_temp_db

Confirm that database id or database name matches temp_db_prefix, error otherwise.

def get_key_type() -> Type

Inherited from: Context.get_key_type

Return key type even when called from a record.

Fields

var context_id -> str

Inherited from: Context.context_id

Unique context identifier.

var dataset -> str

Inherited from: Context.dataset

Dataset of the context, ‘Context.current().dataset’ is used if not specified.

var db -> DbKey

Inherited from: Context.db

Database of the context, ‘Context.current().db’ is used if not specified.

var db_class -> str | None

Override for the database class in module.ClassName format.

var is_deserialized -> bool

Inherited from: Context.is_deserialized

Use this flag to determine if this context instance has been deserialized from data.

var log -> LogKey

Inherited from: Context.log

Log of the context, ‘Context.current().log’ is used if not specified.

var user -> UserKey

Inherited from: Context.user

Current user, ‘Context.current().user’ is used if not specified.

Methods

def delete_all_and_drop_db(self) -> None

Inherited from: Context.delete_all_and_drop_db

IMPORTANT: !!! DESTRUCTIVE – THIS WILL PERMANENTLY DELETE ALL RECORDS WITHOUT THE POSSIBILITY OF RECOVERY …

def delete_many(self, keys: Optional[Iterable[KeyProtocol]], *, dataset: str | None = None, identity: str | None = None) -> None

Inherited from: Context.delete_many

Delete records using an iterable of keys …

def delete_one(self, key_type: Type[~TKey], key: Union[~TKey, KeyProtocol, tuple, str, ForwardRef(None)], *, dataset: str | None = None, identity: str | None = None) -> None

Inherited from: Context.delete_one

Delete one record for the specified key type using its key in one of several possible formats …

def get_key(self) -> ContextKey

Inherited from: Context.get_key

Return a new key object whose fields are populated from self; does not return self.

def get_logger(self, name: str) -> logging.Logger

Inherited from: Context.get_logger

Get logger for the specified name, invoke with name as the argument.

def init_all(self) -> None

Inherited from: Context.init_all

Invoke ‘init’ for each class in the order from base to derived, then validate against schema.

def load_all(self, record_type: Type[~TRecord], *, dataset: str | None = None, identity: str | None = None) -> Optional[Iterable[Optional[~TRecord]]]

Inherited from: Context.load_all

Load all records of the specified type and its subtypes (excludes other types in the same DB table) …

def load_filter(self, record_type: Type[~TRecord], filter_obj: ~TRecord, *, dataset: str | None = None, identity: str | None = None) -> Iterable[~TRecord]

Inherited from: Context.load_filter

Load records where values of those fields that are set in the filter match the filter …

def load_many(self, record_type: Type[~TRecord], records_or_keys: Optional[Iterable[Union[~TRecord, KeyProtocol, tuple, str, ForwardRef(None)]]], *, dataset: str | None = None, identity: str | None = None) -> Optional[Iterable[Optional[~TRecord]]]

Inherited from: Context.load_many

Load records using a list of keys (if a record is passed instead of a key, it is returned without DB lookup), the result must have the same order as …

def load_one(self, record_type: Type[~TRecord], record_or_key: Union[~TRecord, KeyProtocol, ForwardRef(None)], *, dataset: str | None = None, identity: str | None = None, is_key_optional: bool = False, is_record_optional: bool = False) -> Optional[~TRecord]

Inherited from: Context.load_one

Load a single record using a key (if a record is passed instead of a key, it is returned without DB lookup) …

def save_many(self, records: Iterable[RecordProtocol], *, dataset: str | None = None, identity: str | None = None) -> None

Inherited from: Context.save_many

Save records to storage …

def save_one(self, record: RecordProtocol | None, *, dataset: str | None = None, identity: str | None = None) -> None

Inherited from: Context.save_one

Save records to storage …