Module: key_util
Expand source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ast
import inspect
import textwrap
from typing import List
from typing import Type
class KeyUtil:
"""Utilities for working with keys."""
# TODO: Extract from key class instead
@classmethod
def get_key_fields(cls, record_type: Type) -> List[str] | None:
"""
Get primary key fields by parsing the source of 'get_key' method of 'record_type'.
Notes:
This method parses the source code of 'get_key' method of 'record_type' and returns all
instance fields it accesses in the order of access, for example if 'get_key' source is:
def get_key(self) -> MyKey:
return MyKey(key_field_1=self.key_field_1, key_field_2=self.key_field_2)
this method will return:
["key_field_1", "key_field_2"]
Args:
record_type: Class where 'get_key' method is implemented
"""
# Get source code for the 'get_key' method
if hasattr(record_type, "get_key"):
get_key_source = inspect.getsource(record_type.get_key)
else:
# TODO: Determine if a flag is needed for element types to prevent keys lookup
return None
# raise RuntimeError(
# f"Cannot get key fields because record type {record_type.__name__} "
# f"does not implement 'get_key' method."
# )
# Because 'ast' expects the code to be correct as though it is at top level,
# remove excess indent from the source to make it suitable for parsing
get_key_source = textwrap.dedent(get_key_source)
# Extract field names from the AST of 'get_key' method
get_key_ast = ast.parse(get_key_source)
key_fields = []
for node in ast.walk(get_key_ast):
# Find every instance field of accessed inside the source of 'get_key' method.
# Accumulate in list in the order they are accessed
if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) and node.value.id == "self":
key_fields.append(node.attr)
return key_fields
Classes
class KeyUtil
-
Utilities for working with keys.
Expand source code
class KeyUtil: """Utilities for working with keys.""" # TODO: Extract from key class instead @classmethod def get_key_fields(cls, record_type: Type) -> List[str] | None: """ Get primary key fields by parsing the source of 'get_key' method of 'record_type'. Notes: This method parses the source code of 'get_key' method of 'record_type' and returns all instance fields it accesses in the order of access, for example if 'get_key' source is: def get_key(self) -> MyKey: return MyKey(key_field_1=self.key_field_1, key_field_2=self.key_field_2) this method will return: ["key_field_1", "key_field_2"] Args: record_type: Class where 'get_key' method is implemented """ # Get source code for the 'get_key' method if hasattr(record_type, "get_key"): get_key_source = inspect.getsource(record_type.get_key) else: # TODO: Determine if a flag is needed for element types to prevent keys lookup return None # raise RuntimeError( # f"Cannot get key fields because record type {record_type.__name__} " # f"does not implement 'get_key' method." # ) # Because 'ast' expects the code to be correct as though it is at top level, # remove excess indent from the source to make it suitable for parsing get_key_source = textwrap.dedent(get_key_source) # Extract field names from the AST of 'get_key' method get_key_ast = ast.parse(get_key_source) key_fields = [] for node in ast.walk(get_key_ast): # Find every instance field of accessed inside the source of 'get_key' method. # Accumulate in list in the order they are accessed if isinstance(node, ast.Attribute) and isinstance(node.value, ast.Name) and node.value.id == "self": key_fields.append(node.attr) return key_fields
Class methods
def get_key_fields(record_type: Type) -> Optional[List[str]]
-
Get primary key fields by parsing the source of ‘get_key’ method of ‘record_type’.
Notes
This method parses the source code of ‘get_key’ method of ‘record_type’ and returns all instance fields it accesses in the order of access, for example if ‘get_key’ source is:
def get_key(self) -> MyKey: return MyKey(key_field_1=self.key_field_1, key_field_2=self.key_field_2)
this method will return:
[“key_field_1”, “key_field_2”]
Args
record_type
- Class where ‘get_key’ method is implemented