Module: sqlite_db

Expand source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sqlite3
from collections import defaultdict
from dataclasses import dataclass
from itertools import groupby
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Tuple
from typing import Type
from cl.runtime.context.context import Context
from cl.runtime.db.db import Db
from cl.runtime.db.protocols import TKey
from cl.runtime.db.protocols import TRecord
from cl.runtime.db.sql.sqlite_schema_manager import SqliteSchemaManager
from cl.runtime.file.file_util import FileUtil
from cl.runtime.log.exceptions.user_error import UserError
from cl.runtime.records.protocols import KeyProtocol
from cl.runtime.records.protocols import RecordProtocol
from cl.runtime.records.protocols import is_key
from cl.runtime.records.record_util import RecordUtil
from cl.runtime.schema.schema import Schema
from cl.runtime.serialization.flat_dict_serializer import FlatDictSerializer
from cl.runtime.settings.project_settings import ProjectSettings

_connection_dict: Dict[str, sqlite3.Connection] = {}
"""Dict of sqlite3 Connection instances keyed by db_id, stored outside the class to avoid serialization."""

_schema_manager_dict: Dict[str, SqliteSchemaManager] = {}
"""Dict of SqliteSchemaManager instances keyed by db_id, stored outside the class to avoid serialization."""


def dict_factory(cursor, row):
    """
    Row factory for sqlite3 connections that converts each result row into a dictionary.

    Keys are the column names reported by cursor.description, in column order, and values are the
    corresponding entries of the row tuple.
    """
    column_names = (description[0] for description in cursor.description)
    return dict(zip(column_names, row))


@dataclass(slots=True, kw_only=True)
class SqliteDb(Db):
    """
    SQLite database without dataset support and without a mile-wide table for inheritance.

    The dataset and identity arguments accepted by the load, save and delete methods are not used.
    The connection and schema manager for each db_id are cached in module-level dictionaries
    rather than on the instance, so that they are not serialized with it.
    """

    def batch_size(self) -> int:
        """Not implemented; currently returns None despite the int annotation."""
        # NOTE(review): confirm whether callers rely on this and whether it should return a real batch size
        pass

    @classmethod
    def _add_where_keys_in_clause(
        cls,
        sql_statement: str,
        key_fields: Tuple[str, ...],
        columns_mapping: Dict[str, str],
        keys_len: int,
    ) -> str:
        """
        Append a "WHERE (key_column1, ...) IN ((?, ...), (?, ...), ...)" clause to sql_statement.

        Args:
            sql_statement: SQL statement to extend (without a trailing semicolon).
            key_fields: Key field names, each mapped to its column name through columns_mapping.
            columns_mapping: Mapping from field name to database column name.
            keys_len: Number of keys, i.e. the number of placeholder tuples to generate.

        Returns:
            The extended statement, or sql_statement unchanged when key_fields is empty
            (the statement then applies to every row of the table).
        """

        # if key fields isn't empty add WHERE clause
        if key_fields:
            # One "(?, ?, ...)" tuple per key, with one placeholder per key field
            value_places = ", ".join([f'({", ".join(["?"] * len(key_fields))})' for _ in range(keys_len)])
            key_column_str = ", ".join([f'"{columns_mapping[key]}"' for key in key_fields])

            # add WHERE clause to sql_statement (row-value syntax requires SQLite 3.15.0 or later)
            sql_statement += f" WHERE ({key_column_str}) IN ({value_places})"

        return sql_statement

    @classmethod
    def _serialize_keys_to_flat_tuple(
        cls,
        keys: Iterable[KeyProtocol],
        key_fields: Tuple[str, ...],
        serializer,
    ) -> Tuple[Any, ...]:
        """
        Sequentially serialize key fields for each key in keys into a flat tuple of values.

        The order matches the placeholders generated by _add_where_keys_in_clause: all key fields
        of the first key, then all key fields of the second key, and so on.
        All keys are expected to be of the same type, for which key_fields are specified.
        """

        return tuple(serializer.serialize_data(getattr(key, key_field)) for key in keys for key_field in key_fields)

    def load_one(
        self,
        record_type: Type[TRecord],
        record_or_key: TRecord | KeyProtocol | None,
        *,
        dataset: str | None = None,
        identity: str | None = None,
        is_key_optional: bool = False,
        is_record_optional: bool = False,
    ) -> TRecord | None:
        """
        Load a single record by delegating to load_many (a record passed in is returned as is).

        Returns None if record_or_key is None and is_key_optional is True, or if the record
        is not found and is_record_optional is True.

        Raises:
            UserError: If record_or_key is None and is_key_optional is False, or if the record
                is not found and is_record_optional is False.
        """
        # Check for an empty key
        if record_or_key is None:
            if is_key_optional:
                return None
            else:
                raise UserError(f"Key is None when trying to load record type {record_type.__name__} from DB.")

        # Delegate to load_many
        result = next(iter(self.load_many(record_type, [record_or_key], dataset=dataset, identity=identity)))

        # Check if the record was not found
        if not is_record_optional and result is None:
            raise UserError(f"{record_type.__name__} record is not found for key {record_or_key}")
        return result

    def load_many(
        self,
        record_type: Type[TRecord],
        records_or_keys: Iterable[TRecord | KeyProtocol | tuple | str | None] | None,
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> Iterable[TRecord | None] | None:
        """
        Yield one result per element of records_or_keys, preserving input order.

        None elements yield None. Elements for which is_key returns False (records) are yielded
        unchanged, without a lookup. Keys are looked up in bulk for each run of consecutive keys
        of the same key type, and a key with no stored record yields None.
        If records_or_keys is None, nothing is yielded.
        """
        # Treat a missing input iterable as empty instead of failing inside groupby
        if records_or_keys is None:
            return

        serializer = FlatDictSerializer()
        schema_manager = self._get_schema_manager()

        # itertools.groupby groups only consecutive elements, which preserves the original order of
        # records_or_keys. Group by key type, then by whether the element is a key or a record.
        for key_type, records_or_keys_group in groupby(
            records_or_keys, lambda x: x.get_key_type() if x is not None else None
        ):
            # Pass None elements through unchanged
            if key_type is None:
                yield from records_or_keys_group
                continue

            for is_key_group, keys_group in groupby(records_or_keys_group, lambda x: is_key(x)):
                # Return records directly, without a database lookup
                if not is_key_group:
                    yield from keys_group
                    continue

                # groupby sub-iterators are single-pass and have no len(), so materialize the group
                keys_group = tuple(keys_group)

                # Return None for all keys in the group if the table doesn't exist
                table_name = schema_manager.table_name_for_type(key_type)
                if table_name not in schema_manager.existing_tables():
                    yield from (None for _ in keys_group)
                    continue

                key_fields = schema_manager.get_primary_keys(key_type)
                columns_mapping = schema_manager.get_columns_mapping(key_type)

                # Build a parameterized SELECT filtered by the keys of this group.
                # NOTE(review): binds len(keys_group) * len(key_fields) parameters in one statement;
                # very large groups may exceed SQLite's host-parameter limit (SQLITE_MAX_VARIABLE_NUMBER).
                sql_statement = f'SELECT * FROM "{table_name}"'
                sql_statement = self._add_where_keys_in_clause(
                    sql_statement, key_fields, columns_mapping, len(keys_group)
                )
                sql_statement += ";"

                # serialize keys to tuple
                query_values = self._serialize_keys_to_flat_tuple(keys_group, key_fields, serializer)

                cursor = self._get_connection().cursor()
                cursor.execute(sql_statement, query_values)

                # Map database column names back to field names
                reversed_columns_mapping = {v: k for k, v in columns_mapping.items()}

                # TODO (Roman): investigate performance impact from this ordering approach
                # Bulk load returns rows in arbitrary order, so index all rows of the group by key
                # first and then yield them in the order of the input keys
                result = {}
                for data in cursor.fetchall():
                    # TODO (Roman): select only needed columns on db side.
                    # NULL columns are dropped before the data is passed to the deserializer
                    data = {reversed_columns_mapping[k]: v for k, v in data.items() if v is not None}
                    deserialized_data = serializer.deserialize_data(data)

                    # TODO (Roman): make key hashable and remove conversion of key to str
                    result[str(deserialized_data.get_key())] = deserialized_data

                # yield records according to input keys order
                for key in keys_group:
                    yield result.get(str(key))

    def load_all(
        self,
        record_type: Type[TRecord],
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> Iterable[TRecord | None] | None:
        """
        Load all rows whose _type column is one of the type names returned by
        Schema.get_type_successors(record_type), sorted with RecordUtil.sort_records_by_key.

        Returns an empty list if the table for record_type does not exist.
        """
        serializer = FlatDictSerializer()
        schema_manager = self._get_schema_manager()

        table_name: str = schema_manager.table_name_for_type(record_type)

        # if table doesn't exist return empty list
        if table_name not in schema_manager.existing_tables():
            return list()

        # get subtypes for record_type and use them in match condition
        subtype_names = tuple(t.__name__ for t in Schema.get_type_successors(record_type))
        value_placeholders = ", ".join(["?"] * len(subtype_names))
        sql_statement = f'SELECT * FROM "{table_name}" WHERE _type in ({value_placeholders});'

        # Map database column names back to field names
        reversed_columns_mapping = {
            v: k for k, v in schema_manager.get_columns_mapping(record_type.get_key_type()).items()
        }

        cursor = self._get_connection().cursor()
        cursor.execute(sql_statement, subtype_names)

        # TODO: Implement sort in query and restore yield to support large collections
        result = []
        for data in cursor.fetchall():
            # TODO (Roman): Select only needed columns on db side.
            data = {reversed_columns_mapping[k]: v for k, v in data.items() if v is not None}
            result.append(serializer.deserialize_data(data))

        return RecordUtil.sort_records_by_key(result)

    def load_filter(
        self,
        record_type: Type[TRecord],
        filter_obj: TRecord,
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> Iterable[TRecord]:
        """Not implemented for SQLite; always raises NotImplementedError."""
        raise NotImplementedError()

    def save_one(
        self,
        record: RecordProtocol | None,
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> None:
        """Save a single record by delegating to save_many; a None record is skipped."""
        return self.save_many([record], dataset=dataset, identity=identity)

    def save_many(
        self,
        records: Iterable[RecordProtocol],
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> None:
        """
        Save records with REPLACE INTO, grouped by key type.

        For each key type, the table is created if it does not exist, and all records of the group
        are written in a single statement followed by a commit. For key types without primary key
        columns, the existing rows are deleted first (see the workaround below). None elements are skipped.
        """
        # Materialize the input once. It is traversed twice below (on_save hooks, then grouping), and a
        # one-shot iterator such as a generator would otherwise be empty on the second pass.
        records = [record for record in records if record is not None]

        # Call on_save if defined
        # TODO: Refactor on_save
        for record in records:
            if hasattr(record, "on_save"):
                record.on_save()

        serializer = FlatDictSerializer()
        schema_manager = self._get_schema_manager()

        # TODO (Roman): improve grouping
        grouped_records = defaultdict(list)
        for record in records:
            grouped_records[record.get_key_type()].append(record)

        for key_type, records_group in grouped_records.items():
            # serialize records
            serialized_records = [serializer.serialize_data(rec, is_root=True) for rec in records_group]

            # Union of fields across all records of the group; the order is arbitrary but is used
            # consistently for both the column list and the values below
            all_fields = list({k for rec in serialized_records for k in rec.keys()})

            # fill sql_values with ordered values from serialized records
            # if field isn't in some records - fill with None
            sql_values = tuple(
                serialized_record[k] if k in serialized_record else None
                for serialized_record in serialized_records
                for k in all_fields
            )

            columns_mapping = schema_manager.get_columns_mapping(key_type)
            quoted_columns = [f'"{columns_mapping[field]}"' for field in all_fields]
            columns_str = ", ".join(quoted_columns)

            # One "(?, ?, ...)" tuple per record, with one placeholder per field
            value_placeholders = ", ".join([f"({', '.join(['?']*len(all_fields))})" for _ in range(len(records_group))])

            table_name = schema_manager.table_name_for_type(key_type)

            primary_keys = [columns_mapping[primary_key] for primary_key in schema_manager.get_primary_keys(key_type)]

            schema_manager.create_table(
                table_name, columns_mapping.values(), if_not_exists=True, primary_keys=primary_keys
            )

            sql_statement = f'REPLACE INTO "{table_name}" ({columns_str}) VALUES {value_placeholders};'

            if not primary_keys:
                # TODO (Roman): this is a workaround for handling singleton records.
                #  Since they don't have primary keys, we can't automatically replace existing records.
                #  So this code just deletes the existing records before saving.
                #  As a possible solution, we can introduce some mandatory primary key that isn't based on the
                #  key fields.
                self.delete_many((rec.get_key() for rec in records_group))

            connection = self._get_connection()
            cursor = connection.cursor()
            cursor.execute(sql_statement, sql_values)

            connection.commit()

    def delete_one(
        self,
        key_type: Type[TKey],
        key: TKey | KeyProtocol | tuple | str | None,
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> None:
        """Delete a single record by delegating to delete_many; a None key is a no-op."""
        # TODO (Yauheni): Add implementation independent from delete_many()
        self.delete_many([key], dataset=dataset, identity=identity)

    def delete_many(
        self,
        keys: Iterable[KeyProtocol] | None,
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> None:
        """
        Delete records for the given keys, grouped by key type, with one DELETE and commit per group.

        Key types whose table does not exist are skipped. If keys is None, or an element is None,
        it is ignored. For key types without key fields, the DELETE has no WHERE clause and removes
        every row of the table.
        """
        # Nothing to delete
        if keys is None:
            return

        serializer = FlatDictSerializer()
        schema_manager = self._get_schema_manager()

        # TODO (Roman): improve grouping
        grouped_keys = defaultdict(list)
        for key in keys:
            # Skip None keys (e.g. delete_one called with key=None)
            if key is not None:
                grouped_keys[key.get_key_type()].append(key)

        for key_type, keys_group in grouped_keys.items():
            table_name = schema_manager.table_name_for_type(key_type)

            existing_tables = schema_manager.existing_tables()
            if table_name not in existing_tables:
                continue

            key_fields = schema_manager.get_primary_keys(key_type)
            columns_mapping = schema_manager.get_columns_mapping(key_type)

            # construct sql_statement with placeholders for values
            sql_statement = f'DELETE FROM "{table_name}"'
            sql_statement = self._add_where_keys_in_clause(sql_statement, key_fields, columns_mapping, len(keys_group))
            sql_statement += ";"

            # serialize keys to tuple
            query_values = self._serialize_keys_to_flat_tuple(keys_group, key_fields, serializer)

            # perform delete query
            connection = self._get_connection()
            cursor = connection.cursor()
            cursor.execute(sql_statement, query_values)
            connection.commit()

    def delete_all_and_drop_db(self) -> None:
        """Close the connection and delete the database file, only if both db_id and filename pass the temp-db check."""
        # Check that db_id matches temp_db_prefix
        Context.error_if_not_temp_db(self.db_id)

        # Close connection
        self.close_connection()

        # Check that filename also matches temp_db_prefix. It should normally match db_id
        # we already checked, but given the critical importance of this check will check db_filename
        # as well in case this approach changes later.
        db_file_path = self._get_db_file()
        db_filename = os.path.basename(db_file_path)
        Context.error_if_not_temp_db(db_filename)

        # Delete database file if exists, all checks have been performed
        if os.path.exists(db_file_path):
            os.remove(db_file_path)

    def close_connection(self) -> None:
        """Close the cached connection for db_id, if any, and drop the cached connection and schema manager."""
        if (connection := _connection_dict.get(self.db_id, None)) is not None:
            # Close connection
            connection.close()
            # Remove from dictionaries so the connection can be reopened on next access.
            # The schema manager may never have been created (e.g. only is_empty was called),
            # so do not assume it is present.
            del _connection_dict[self.db_id]
            _schema_manager_dict.pop(self.db_id, None)

    def _get_connection(self) -> sqlite3.Connection:
        """
        Return the cached sqlite3 connection for db_id, opening and caching a new one on first access.

        New connections are opened with check_same_thread=False and use dict_factory as row factory.
        """
        if (connection := _connection_dict.get(self.db_id, None)) is None:
            # TODO: Implement dispose logic
            db_file = self._get_db_file()
            connection = sqlite3.connect(db_file, check_same_thread=False)
            connection.row_factory = dict_factory
            _connection_dict[self.db_id] = connection
        return connection

    def _get_schema_manager(self) -> SqliteSchemaManager:
        """Return the cached SqliteSchemaManager for db_id, creating it over the cached connection on first access."""
        if (result := _schema_manager_dict.get(self.db_id, None)) is None:
            # TODO: Implement dispose logic
            connection = self._get_connection()
            result = SqliteSchemaManager(sqlite_connection=connection)
            _schema_manager_dict[self.db_id] = result
        return result

    def _get_db_file(self) -> str:
        """Get database file path from db_id, applying the appropriate formatting conventions."""

        # Check that db_id is a valid filename
        filename = self.db_id
        FileUtil.check_valid_filename(filename)

        # Get dir for database
        db_dir = ProjectSettings.get_databases_dir()

        result = os.path.join(db_dir, f"{filename}.sqlite")
        return result

    def is_empty(self) -> bool:
        """Return True if the database has no tables or all tables are empty."""
        connection = self._get_connection()
        cursor = connection.cursor()

        # Check if there are any tables in the SQLite database
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()

        # If no tables are present, the database is empty
        if not tables:
            return True

        # Check if all tables are empty
        for row in tables:
            # Double any embedded quotes so the name is a valid quoted SQL identifier
            quoted_name = row["name"].replace('"', '""')
            cursor.execute(f'SELECT COUNT(*) FROM "{quoted_name}";')
            count = cursor.fetchone()["COUNT(*)"]

            # If any table has data, the database is not empty
            if count > 0:
                return False

        return True

Functions

def dict_factory(cursor, row)

sqlite3 row factory that returns each result row as a dictionary.

Classes

class SqliteDb (*, db_id: str = None)

SQLite database without dataset support and without a mile-wide table for inheritance.

Expand source code
@dataclass(slots=True, kw_only=True)
class SqliteDb(Db):
    """
    SQLite database without dataset support and without a mile-wide table for inheritance.

    The dataset and identity arguments accepted by the load, save and delete methods are not used.
    The connection and schema manager for each db_id are cached in module-level dictionaries
    rather than on the instance, so that they are not serialized with it.
    """

    def batch_size(self) -> int:
        """Not implemented; currently returns None despite the int annotation."""
        # NOTE(review): confirm whether callers rely on this and whether it should return a real batch size
        pass

    @classmethod
    def _add_where_keys_in_clause(
        cls,
        sql_statement: str,
        key_fields: Tuple[str, ...],
        columns_mapping: Dict[str, str],
        keys_len: int,
    ) -> str:
        """
        Append a "WHERE (key_column1, ...) IN ((?, ...), (?, ...), ...)" clause to sql_statement.

        Args:
            sql_statement: SQL statement to extend (without a trailing semicolon).
            key_fields: Key field names, each mapped to its column name through columns_mapping.
            columns_mapping: Mapping from field name to database column name.
            keys_len: Number of keys, i.e. the number of placeholder tuples to generate.

        Returns:
            The extended statement, or sql_statement unchanged when key_fields is empty
            (the statement then applies to every row of the table).
        """

        # if key fields isn't empty add WHERE clause
        if key_fields:
            # One "(?, ?, ...)" tuple per key, with one placeholder per key field
            value_places = ", ".join([f'({", ".join(["?"] * len(key_fields))})' for _ in range(keys_len)])
            key_column_str = ", ".join([f'"{columns_mapping[key]}"' for key in key_fields])

            # add WHERE clause to sql_statement (row-value syntax requires SQLite 3.15.0 or later)
            sql_statement += f" WHERE ({key_column_str}) IN ({value_places})"

        return sql_statement

    @classmethod
    def _serialize_keys_to_flat_tuple(
        cls,
        keys: Iterable[KeyProtocol],
        key_fields: Tuple[str, ...],
        serializer,
    ) -> Tuple[Any, ...]:
        """
        Sequentially serialize key fields for each key in keys into a flat tuple of values.

        The order matches the placeholders generated by _add_where_keys_in_clause: all key fields
        of the first key, then all key fields of the second key, and so on.
        All keys are expected to be of the same type, for which key_fields are specified.
        """

        return tuple(serializer.serialize_data(getattr(key, key_field)) for key in keys for key_field in key_fields)

    def load_one(
        self,
        record_type: Type[TRecord],
        record_or_key: TRecord | KeyProtocol | None,
        *,
        dataset: str | None = None,
        identity: str | None = None,
        is_key_optional: bool = False,
        is_record_optional: bool = False,
    ) -> TRecord | None:
        """
        Load a single record by delegating to load_many (a record passed in is returned as is).

        Returns None if record_or_key is None and is_key_optional is True, or if the record
        is not found and is_record_optional is True.

        Raises:
            UserError: If record_or_key is None and is_key_optional is False, or if the record
                is not found and is_record_optional is False.
        """
        # Check for an empty key
        if record_or_key is None:
            if is_key_optional:
                return None
            else:
                raise UserError(f"Key is None when trying to load record type {record_type.__name__} from DB.")

        # Delegate to load_many
        result = next(iter(self.load_many(record_type, [record_or_key], dataset=dataset, identity=identity)))

        # Check if the record was not found
        if not is_record_optional and result is None:
            raise UserError(f"{record_type.__name__} record is not found for key {record_or_key}")
        return result

    def load_many(
        self,
        record_type: Type[TRecord],
        records_or_keys: Iterable[TRecord | KeyProtocol | tuple | str | None] | None,
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> Iterable[TRecord | None] | None:
        """
        Yield one result per element of records_or_keys, preserving input order.

        None elements yield None. Elements for which is_key returns False (records) are yielded
        unchanged, without a lookup. Keys are looked up in bulk for each run of consecutive keys
        of the same key type, and a key with no stored record yields None.
        If records_or_keys is None, nothing is yielded.
        """
        # Treat a missing input iterable as empty instead of failing inside groupby
        if records_or_keys is None:
            return

        serializer = FlatDictSerializer()
        schema_manager = self._get_schema_manager()

        # itertools.groupby groups only consecutive elements, which preserves the original order of
        # records_or_keys. Group by key type, then by whether the element is a key or a record.
        for key_type, records_or_keys_group in groupby(
            records_or_keys, lambda x: x.get_key_type() if x is not None else None
        ):
            # Pass None elements through unchanged
            if key_type is None:
                yield from records_or_keys_group
                continue

            for is_key_group, keys_group in groupby(records_or_keys_group, lambda x: is_key(x)):
                # Return records directly, without a database lookup
                if not is_key_group:
                    yield from keys_group
                    continue

                # groupby sub-iterators are single-pass and have no len(), so materialize the group
                keys_group = tuple(keys_group)

                # Return None for all keys in the group if the table doesn't exist
                table_name = schema_manager.table_name_for_type(key_type)
                if table_name not in schema_manager.existing_tables():
                    yield from (None for _ in keys_group)
                    continue

                key_fields = schema_manager.get_primary_keys(key_type)
                columns_mapping = schema_manager.get_columns_mapping(key_type)

                # Build a parameterized SELECT filtered by the keys of this group.
                # NOTE(review): binds len(keys_group) * len(key_fields) parameters in one statement;
                # very large groups may exceed SQLite's host-parameter limit (SQLITE_MAX_VARIABLE_NUMBER).
                sql_statement = f'SELECT * FROM "{table_name}"'
                sql_statement = self._add_where_keys_in_clause(
                    sql_statement, key_fields, columns_mapping, len(keys_group)
                )
                sql_statement += ";"

                # serialize keys to tuple
                query_values = self._serialize_keys_to_flat_tuple(keys_group, key_fields, serializer)

                cursor = self._get_connection().cursor()
                cursor.execute(sql_statement, query_values)

                # Map database column names back to field names
                reversed_columns_mapping = {v: k for k, v in columns_mapping.items()}

                # TODO (Roman): investigate performance impact from this ordering approach
                # Bulk load returns rows in arbitrary order, so index all rows of the group by key
                # first and then yield them in the order of the input keys
                result = {}
                for data in cursor.fetchall():
                    # TODO (Roman): select only needed columns on db side.
                    # NULL columns are dropped before the data is passed to the deserializer
                    data = {reversed_columns_mapping[k]: v for k, v in data.items() if v is not None}
                    deserialized_data = serializer.deserialize_data(data)

                    # TODO (Roman): make key hashable and remove conversion of key to str
                    result[str(deserialized_data.get_key())] = deserialized_data

                # yield records according to input keys order
                for key in keys_group:
                    yield result.get(str(key))

    def load_all(
        self,
        record_type: Type[TRecord],
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> Iterable[TRecord | None] | None:
        """
        Load all rows whose _type column is one of the type names returned by
        Schema.get_type_successors(record_type), sorted with RecordUtil.sort_records_by_key.

        Returns an empty list if the table for record_type does not exist.
        """
        serializer = FlatDictSerializer()
        schema_manager = self._get_schema_manager()

        table_name: str = schema_manager.table_name_for_type(record_type)

        # if table doesn't exist return empty list
        if table_name not in schema_manager.existing_tables():
            return list()

        # get subtypes for record_type and use them in match condition
        subtype_names = tuple(t.__name__ for t in Schema.get_type_successors(record_type))
        value_placeholders = ", ".join(["?"] * len(subtype_names))
        sql_statement = f'SELECT * FROM "{table_name}" WHERE _type in ({value_placeholders});'

        # Map database column names back to field names
        reversed_columns_mapping = {
            v: k for k, v in schema_manager.get_columns_mapping(record_type.get_key_type()).items()
        }

        cursor = self._get_connection().cursor()
        cursor.execute(sql_statement, subtype_names)

        # TODO: Implement sort in query and restore yield to support large collections
        result = []
        for data in cursor.fetchall():
            # TODO (Roman): Select only needed columns on db side.
            data = {reversed_columns_mapping[k]: v for k, v in data.items() if v is not None}
            result.append(serializer.deserialize_data(data))

        return RecordUtil.sort_records_by_key(result)

    def load_filter(
        self,
        record_type: Type[TRecord],
        filter_obj: TRecord,
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> Iterable[TRecord]:
        """Not implemented for SQLite; always raises NotImplementedError."""
        raise NotImplementedError()

    def save_one(
        self,
        record: RecordProtocol | None,
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> None:
        """Save a single record by delegating to save_many; a None record is skipped."""
        return self.save_many([record], dataset=dataset, identity=identity)

    def save_many(
        self,
        records: Iterable[RecordProtocol],
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> None:
        """
        Save records with REPLACE INTO, grouped by key type.

        For each key type, the table is created if it does not exist, and all records of the group
        are written in a single statement followed by a commit. For key types without primary key
        columns, the existing rows are deleted first (see the workaround below). None elements are skipped.
        """
        # Materialize the input once. It is traversed twice below (on_save hooks, then grouping), and a
        # one-shot iterator such as a generator would otherwise be empty on the second pass.
        records = [record for record in records if record is not None]

        # Call on_save if defined
        # TODO: Refactor on_save
        for record in records:
            if hasattr(record, "on_save"):
                record.on_save()

        serializer = FlatDictSerializer()
        schema_manager = self._get_schema_manager()

        # TODO (Roman): improve grouping
        grouped_records = defaultdict(list)
        for record in records:
            grouped_records[record.get_key_type()].append(record)

        for key_type, records_group in grouped_records.items():
            # serialize records
            serialized_records = [serializer.serialize_data(rec, is_root=True) for rec in records_group]

            # Union of fields across all records of the group; the order is arbitrary but is used
            # consistently for both the column list and the values below
            all_fields = list({k for rec in serialized_records for k in rec.keys()})

            # fill sql_values with ordered values from serialized records
            # if field isn't in some records - fill with None
            sql_values = tuple(
                serialized_record[k] if k in serialized_record else None
                for serialized_record in serialized_records
                for k in all_fields
            )

            columns_mapping = schema_manager.get_columns_mapping(key_type)
            quoted_columns = [f'"{columns_mapping[field]}"' for field in all_fields]
            columns_str = ", ".join(quoted_columns)

            # One "(?, ?, ...)" tuple per record, with one placeholder per field
            value_placeholders = ", ".join([f"({', '.join(['?']*len(all_fields))})" for _ in range(len(records_group))])

            table_name = schema_manager.table_name_for_type(key_type)

            primary_keys = [columns_mapping[primary_key] for primary_key in schema_manager.get_primary_keys(key_type)]

            schema_manager.create_table(
                table_name, columns_mapping.values(), if_not_exists=True, primary_keys=primary_keys
            )

            sql_statement = f'REPLACE INTO "{table_name}" ({columns_str}) VALUES {value_placeholders};'

            if not primary_keys:
                # TODO (Roman): this is a workaround for handling singleton records.
                #  Since they don't have primary keys, we can't automatically replace existing records.
                #  So this code just deletes the existing records before saving.
                #  As a possible solution, we can introduce some mandatory primary key that isn't based on the
                #  key fields.
                self.delete_many((rec.get_key() for rec in records_group))

            connection = self._get_connection()
            cursor = connection.cursor()
            cursor.execute(sql_statement, sql_values)

            connection.commit()

    def delete_one(
        self,
        key_type: Type[TKey],
        key: TKey | KeyProtocol | tuple | str | None,
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> None:
        """Delete a single record by delegating to delete_many; a None key is a no-op."""
        # TODO (Yauheni): Add implementation independent from delete_many()
        self.delete_many([key], dataset=dataset, identity=identity)

    def delete_many(
        self,
        keys: Iterable[KeyProtocol] | None,
        *,
        dataset: str | None = None,
        identity: str | None = None,
    ) -> None:
        """
        Delete records for the given keys, grouped by key type, with one DELETE and commit per group.

        Key types whose table does not exist are skipped. If keys is None, or an element is None,
        it is ignored. For key types without key fields, the DELETE has no WHERE clause and removes
        every row of the table.
        """
        # Nothing to delete
        if keys is None:
            return

        serializer = FlatDictSerializer()
        schema_manager = self._get_schema_manager()

        # TODO (Roman): improve grouping
        grouped_keys = defaultdict(list)
        for key in keys:
            # Skip None keys (e.g. delete_one called with key=None)
            if key is not None:
                grouped_keys[key.get_key_type()].append(key)

        for key_type, keys_group in grouped_keys.items():
            table_name = schema_manager.table_name_for_type(key_type)

            existing_tables = schema_manager.existing_tables()
            if table_name not in existing_tables:
                continue

            key_fields = schema_manager.get_primary_keys(key_type)
            columns_mapping = schema_manager.get_columns_mapping(key_type)

            # construct sql_statement with placeholders for values
            sql_statement = f'DELETE FROM "{table_name}"'
            sql_statement = self._add_where_keys_in_clause(sql_statement, key_fields, columns_mapping, len(keys_group))
            sql_statement += ";"

            # serialize keys to tuple
            query_values = self._serialize_keys_to_flat_tuple(keys_group, key_fields, serializer)

            # perform delete query
            connection = self._get_connection()
            cursor = connection.cursor()
            cursor.execute(sql_statement, query_values)
            connection.commit()

    def delete_all_and_drop_db(self) -> None:
        """Close the connection and delete the database file, only if both db_id and filename pass the temp-db check."""
        # Check that db_id matches temp_db_prefix
        Context.error_if_not_temp_db(self.db_id)

        # Close connection
        self.close_connection()

        # Check that filename also matches temp_db_prefix. It should normally match db_id
        # we already checked, but given the critical importance of this check will check db_filename
        # as well in case this approach changes later.
        db_file_path = self._get_db_file()
        db_filename = os.path.basename(db_file_path)
        Context.error_if_not_temp_db(db_filename)

        # Delete database file if exists, all checks have been performed
        if os.path.exists(db_file_path):
            os.remove(db_file_path)

    def close_connection(self) -> None:
        """Close the cached connection for db_id, if any, and drop the cached connection and schema manager."""
        if (connection := _connection_dict.get(self.db_id, None)) is not None:
            # Close connection
            connection.close()
            # Remove from dictionaries so the connection can be reopened on next access.
            # The schema manager may never have been created (e.g. only is_empty was called),
            # so do not assume it is present.
            del _connection_dict[self.db_id]
            _schema_manager_dict.pop(self.db_id, None)

    def _get_connection(self) -> sqlite3.Connection:
        """
        Return the cached sqlite3 connection for db_id, opening and caching a new one on first access.

        New connections are opened with check_same_thread=False and use dict_factory as row factory.
        """
        if (connection := _connection_dict.get(self.db_id, None)) is None:
            # TODO: Implement dispose logic
            db_file = self._get_db_file()
            connection = sqlite3.connect(db_file, check_same_thread=False)
            connection.row_factory = dict_factory
            _connection_dict[self.db_id] = connection
        return connection

    def _get_schema_manager(self) -> SqliteSchemaManager:
        """Return the cached SqliteSchemaManager for db_id, creating it over the cached connection on first access."""
        if (result := _schema_manager_dict.get(self.db_id, None)) is None:
            # TODO: Implement dispose logic
            connection = self._get_connection()
            result = SqliteSchemaManager(sqlite_connection=connection)
            _schema_manager_dict[self.db_id] = result
        return result

    def _get_db_file(self) -> str:
        """Get database file path from db_id, applying the appropriate formatting conventions."""

        # Check that db_id is a valid filename
        filename = self.db_id
        FileUtil.check_valid_filename(filename)

        # Get dir for database
        db_dir = ProjectSettings.get_databases_dir()

        result = os.path.join(db_dir, f"{filename}.sqlite")
        return result

    def is_empty(self) -> bool:
        """Return True if the database has no tables or all tables are empty."""
        connection = self._get_connection()
        cursor = connection.cursor()

        # Check if there are any tables in the SQLite database
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        tables = cursor.fetchall()

        # If no tables are present, the database is empty
        if not tables:
            return True

        # Check if all tables are empty
        for row in tables:
            # Double any embedded quotes so the name is a valid quoted SQL identifier
            quoted_name = row["name"].replace('"', '""')
            cursor.execute(f'SELECT COUNT(*) FROM "{quoted_name}";')
            count = cursor.fetchone()["COUNT(*)"]

            # If any table has data, the database is not empty
            if count > 0:
                return False

        return True

Ancestors

Static methods

def default() -> Db

Inherited from: Db.default

Default database is initialized from settings and cannot be modified in code.

def get_key_type() -> Type

Inherited from: Db.get_key_type

Return key type even when called from a record.

Fields

var db_id -> str

Inherited from: Db.db_id

Unique database identifier.

Methods

def batch_size(self) -> int
def close_connection(self) -> None

Inherited from: Db.close_connection

Close the database connection to release resource locks.

def delete_all_and_drop_db(self) -> None

Inherited from: Db.delete_all_and_drop_db

IMPORTANT: !!! DESTRUCTIVE – THIS WILL PERMANENTLY DELETE ALL RECORDS WITHOUT THE POSSIBILITY OF RECOVERY …

def delete_many(self, keys: Optional[Iterable[KeyProtocol]], *, dataset: str | None = None, identity: str | None = None) -> None

Inherited from: Db.delete_many

Delete records using an iterable of keys …

def delete_one(self, key_type: Type[~TKey], key: Union[~TKey, KeyProtocol, tuple, str, ForwardRef(None)], *, dataset: str | None = None, identity: str | None = None) -> None

Inherited from: Db.delete_one

Delete one record for the specified key type using its key in one of several possible formats …

def get_key(self) -> DbKey

Inherited from: Db.get_key

Return a new key object whose fields are populated from self; do not return self.

def init_all(self) -> None

Inherited from: Db.init_all

Invoke ‘init’ for each class in the order from base to derived, then validate against schema.

def is_empty(self) -> bool

Return True if the database has no tables or all tables are empty.

def load_all(self, record_type: Type[~TRecord], *, dataset: str | None = None, identity: str | None = None) -> Optional[Iterable[Optional[~TRecord]]]

Inherited from: Db.load_all

Load all records of the specified type and its subtypes (excludes other types in the same DB table) …

def load_filter(self, record_type: Type[~TRecord], filter_obj: ~TRecord, *, dataset: str | None = None, identity: str | None = None) -> Iterable[~TRecord]

Inherited from: Db.load_filter

Load records where values of those fields that are set in the filter match the filter …

def load_many(self, record_type: Type[~TRecord], records_or_keys: Optional[Iterable[Union[~TRecord, KeyProtocol, tuple, str, ForwardRef(None)]]], *, dataset: str | None = None, identity: str | None = None) -> Optional[Iterable[Optional[~TRecord]]]

Inherited from: Db.load_many

Load records using a list of keys (if a record is passed instead of a key, it is returned without DB lookup), the result must have the same order as …

def load_one(self, record_type: Type[~TRecord], record_or_key: Union[~TRecord, KeyProtocol, ForwardRef(None)], *, dataset: str | None = None, identity: str | None = None, is_key_optional: bool = False, is_record_optional: bool = False) -> Optional[~TRecord]

Inherited from: Db.load_one

Load a single record using a key (if a record is passed instead of a key, it is returned without DB lookup) …

def save_many(self, records: Iterable[RecordProtocol], *, dataset: str | None = None, identity: str | None = None) -> None

Inherited from: Db.save_many

Save records to storage …

def save_one(self, record: RecordProtocol | None, *, dataset: str | None = None, identity: str | None = None) -> None

Inherited from: Db.save_one

Save records to storage …