Module: sqlite_schema_manager
Expand source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sqlite3
from dataclasses import dataclass
from inspect import isclass
from typing import Dict
from typing import Iterable
from typing import List
from typing import Tuple
from typing import Type
from typing import cast
from cl.runtime.primitive.case_util import CaseUtil
from cl.runtime.records.protocols import KeyProtocol
from cl.runtime.schema.schema import Schema
@dataclass(slots=True, kw_only=True)
class SqliteSchemaManager:
"""Class to manage the sqlite schema (table names, columns mapping etc.)."""
sqlite_connection: sqlite3.Connection = None
"""Sqlite connection."""
pascalize_column_names: bool = False
"""If True - convert column names to pascal case."""
add_class_to_column_names: bool = True
"""If True - class name will be added to the column name in format ClassName.field_name."""
def create_table(
self,
table_name: str,
columns: Iterable[str],
if_not_exists: bool = True,
primary_keys: List[str] | None = None,
) -> None:
"""
Create sqlite table with given name and columns.
No need to specify column types because sqlite supports dynamic typing.
Mile wide table contains columns for all subtypes.
"""
if_not_exists_part: str = " IF NOT EXISTS" if if_not_exists else ""
columns_str: str = '"' + '", "'.join(columns) + '"'
# construct final create table statement
create_table_statement: str = f"CREATE TABLE{if_not_exists_part} {table_name} ({columns_str});"
# execute create table statement
cursor = self.sqlite_connection.cursor()
cursor.execute(create_table_statement)
if primary_keys:
keys_str = ", ".join([f'"{key}"' for key in primary_keys])
# Make index name based on table name to be unique within database
index_name = f"{table_name}_key_index"
create_unique_index_statement = (
f'CREATE UNIQUE INDEX IF NOT EXISTS "{index_name}" ON "{table_name}" ({keys_str});'
)
cursor.execute(create_unique_index_statement)
self.sqlite_connection.commit()
def delete_table_by_name(self, name: str, if_exists: bool = True) -> None:
"""Delete table in db."""
cursor = self.sqlite_connection.cursor()
if_exists_part: str = " IF EXISTS" if if_exists else ""
cursor.execute(f"DROP TABLE {if_exists_part} '{name}';")
self.sqlite_connection.commit()
def table_name_for_type(self, type_: Type) -> str:
"""Return table name for the given type."""
# Return key type name inclusive of Key suffix
key_type = cast(KeyProtocol, type_).get_key_type()
return key_type.__name__ # TODO: Also include module
def existing_tables(self) -> List[str]:
"""Return existing tables in db."""
cursor = self.sqlite_connection.cursor()
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
return [select_res["name"] for select_res in cursor.fetchall()]
def _get_type_fields(self, type_: Type) -> Dict[str, Type]: # TODO: Consolidate this and similar code in Schema
"""Return field name and type of annotation based type declaration."""
return type_.__annotations__
# TODO (Roman): make cached but only for key types
def get_columns_mapping(self, type_: Type) -> Dict[str, str]:
"""Collect all types in hierarchy and check type conflicts for fields with the same name."""
types_in_hierarchy = Schema.get_types_in_hierarchy(type_)
key_type = cast(KeyProtocol, type_).get_key_type()
# Get table name inclusive of Key suffix if present
key_fields_class_name: str = key_type.__name__ # TODO: Also include module
# Get fields
key_fields = self._get_type_fields(key_type)
# {field_name: (subclass_name, field_type)}
all_fields: Dict[str, Tuple[str, Type]] = {
key_field_name: (key_fields_class_name, key_field_type)
for key_field_name, key_field_type in key_fields.items()
}
for type_ in types_in_hierarchy:
fields = self._get_type_fields(type_).items()
for field_name, field_type in fields:
existing_field = all_fields.get(field_name)
if existing_field is not None:
if not isclass(field_type):
# TODO (Roman): support union validation in schema
# Skip type checking for fields with non-class annotation (e.g. Union)
continue
# Check if fields with the same name have compatible type
if not issubclass(field_type, existing_field[1]):
raise TypeError(
f"Field {field_name}: {field_type} of class {type_.__name__} conflicts with the same field "
f"{field_name}: {existing_field[1]} in base class {existing_field[0]}"
)
else:
all_fields[field_name] = (type_.__name__, field_type)
columns_mapping = {"_type": "_type"}
for field_name, (class_name, _) in all_fields.items():
field_name = field_name if not self.pascalize_column_names else CaseUtil.snake_to_pascal_case(field_name)
column_name = (
f"{class_name}." if self.add_class_to_column_names and class_name is not None else ""
) + field_name
columns_mapping[field_name] = column_name
return columns_mapping
def get_primary_keys(self, type_: Type) -> Tuple[str, ...]:
"""Return list of primary key fields."""
key_type = cast(KeyProtocol, type_).get_key_type()
key_fields = self._get_type_fields(key_type)
return tuple(key_fields.keys())
Classes
class SqliteSchemaManager (*, sqlite_connection: sqlite3.Connection = None, pascalize_column_names: bool = False, add_class_to_column_names: bool = True)
-
Class to manage the sqlite schema (table names, columns mapping etc.).
Expand source code
@dataclass(slots=True, kw_only=True) class SqliteSchemaManager: """Class to manage the sqlite schema (table names, columns mapping etc.).""" sqlite_connection: sqlite3.Connection = None """Sqlite connection.""" pascalize_column_names: bool = False """If True - convert column names to pascal case.""" add_class_to_column_names: bool = True """If True - class name will be added to the column name in format ClassName.field_name.""" def create_table( self, table_name: str, columns: Iterable[str], if_not_exists: bool = True, primary_keys: List[str] | None = None, ) -> None: """ Create sqlite table with given name and columns. No need to specify column types because sqlite supports dynamic typing. Mile wide table contains columns for all subtypes. """ if_not_exists_part: str = " IF NOT EXISTS" if if_not_exists else "" columns_str: str = '"' + '", "'.join(columns) + '"' # construct final create table statement create_table_statement: str = f"CREATE TABLE{if_not_exists_part} {table_name} ({columns_str});" # execute create table statement cursor = self.sqlite_connection.cursor() cursor.execute(create_table_statement) if primary_keys: keys_str = ", ".join([f'"{key}"' for key in primary_keys]) # Make index name based on table name to be unique within database index_name = f"{table_name}_key_index" create_unique_index_statement = ( f'CREATE UNIQUE INDEX IF NOT EXISTS "{index_name}" ON "{table_name}" ({keys_str});' ) cursor.execute(create_unique_index_statement) self.sqlite_connection.commit() def delete_table_by_name(self, name: str, if_exists: bool = True) -> None: """Delete table in db.""" cursor = self.sqlite_connection.cursor() if_exists_part: str = " IF EXISTS" if if_exists else "" cursor.execute(f"DROP TABLE {if_exists_part} '{name}';") self.sqlite_connection.commit() def table_name_for_type(self, type_: Type) -> str: """Return table name for the given type.""" # Return key type name inclusive of Key suffix key_type = cast(KeyProtocol, type_).get_key_type() return 
key_type.__name__ # TODO: Also include module def existing_tables(self) -> List[str]: """Return existing tables in db.""" cursor = self.sqlite_connection.cursor() cursor.execute("SELECT name FROM sqlite_master WHERE type='table';") return [select_res["name"] for select_res in cursor.fetchall()] def _get_type_fields(self, type_: Type) -> Dict[str, Type]: # TODO: Consolidate this and similar code in Schema """Return field name and type of annotation based type declaration.""" return type_.__annotations__ # TODO (Roman): make cached but only for key types def get_columns_mapping(self, type_: Type) -> Dict[str, str]: """Collect all types in hierarchy and check type conflicts for fields with the same name.""" types_in_hierarchy = Schema.get_types_in_hierarchy(type_) key_type = cast(KeyProtocol, type_).get_key_type() # Get table name inclusive of Key suffix if present key_fields_class_name: str = key_type.__name__ # TODO: Also include module # Get fields key_fields = self._get_type_fields(key_type) # {field_name: (subclass_name, field_type)} all_fields: Dict[str, Tuple[str, Type]] = { key_field_name: (key_fields_class_name, key_field_type) for key_field_name, key_field_type in key_fields.items() } for type_ in types_in_hierarchy: fields = self._get_type_fields(type_).items() for field_name, field_type in fields: existing_field = all_fields.get(field_name) if existing_field is not None: if not isclass(field_type): # TODO (Roman): support union validation in schema # Skip type checking for fields with non-class annotation (e.g. 
Union) continue # Check if fields with the same name have compatible type if not issubclass(field_type, existing_field[1]): raise TypeError( f"Field {field_name}: {field_type} of class {type_.__name__} conflicts with the same field " f"{field_name}: {existing_field[1]} in base class {existing_field[0]}" ) else: all_fields[field_name] = (type_.__name__, field_type) columns_mapping = {"_type": "_type"} for field_name, (class_name, _) in all_fields.items(): field_name = field_name if not self.pascalize_column_names else CaseUtil.snake_to_pascal_case(field_name) column_name = ( f"{class_name}." if self.add_class_to_column_names and class_name is not None else "" ) + field_name columns_mapping[field_name] = column_name return columns_mapping def get_primary_keys(self, type_: Type) -> Tuple[str, ...]: """Return list of primary key fields.""" key_type = cast(KeyProtocol, type_).get_key_type() key_fields = self._get_type_fields(key_type) return tuple(key_fields.keys())
Fields
var add_class_to_column_names -> bool
-
If True – class name will be added to the column name in format ClassName.field_name.
var pascalize_column_names -> bool
-
If True – convert column names to pascal case.
var sqlite_connection -> sqlite3.Connection
-
Sqlite connection.
Methods
def create_table(self, table_name: str, columns: Iterable[str], if_not_exists: bool = True, primary_keys: Optional[List[str]] = None) -> None
-
Create sqlite table with given name and columns.
No need to specify column types because sqlite supports dynamic typing. The table is very wide ("mile wide") because it contains columns for all subtypes.
def delete_table_by_name(self, name: str, if_exists: bool = True) -> None
-
Delete table in db.
def existing_tables(self) -> List[str]
-
Return existing tables in db.
def get_columns_mapping(self, type_: Type) -> Dict[str, str]
-
Collect all types in hierarchy, check type conflicts for fields with the same name, and return the mapping of field names to column names.
def get_primary_keys(self, type_: Type) -> Tuple[str, ...]
-
Return tuple of primary key field names.
def table_name_for_type(self, type_: Type) -> str
-
Return table name for the given type.