Module: template_prompt
Expand source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from cl.runtime import RecordMixin
from cl.runtime.log.exceptions.user_error import UserError
from cl.runtime.primitive.case_util import CaseUtil
from cl.runtime.records.dataclasses_extensions import missing
from cl.runtime.records.protocols import TDataDict
from cl.runtime.serialization.dict_serializer import DictSerializer
from cl.runtime.serialization.string_serializer import StringSerializer
from cl.convince.prompts.prompt import Prompt
from cl.convince.prompts.prompt_key import PromptKey
# Module-level serializer instances, created once at import and used by TemplatePrompt.render
_data_serializer = DictSerializer()
"""Serializer used to serialize params object for rendering the template."""
_key_serializer = StringSerializer()
"""Serializer used to serialize keys for error reporting."""
@dataclass(slots=True, kw_only=True)
class TemplatePrompt(Prompt, ABC):
    """Uses a template to render the prompt, param names are PascalCase in curly braces.

    This class is abstract: 'render' prepares the parameters and reports errors, while the
    actual substitution is delegated to '_render', which derived classes must implement.
    """

    template: str = missing()
    """Template text to render, param names are PascalCase in curly braces."""

    def get_key(self) -> PromptKey:
        """Return a new PromptKey populated with 'prompt_id' of this prompt."""
        return PromptKey(prompt_id=self.prompt_id)

    def render(self, params: RecordMixin) -> str:
        """Use data from the parameters object of 'params_type' to render the template.

        Args:
            params: Record whose fields with non-None values become template parameters,
                field names are converted from snake_case to PascalCase.

        Returns:
            The string returned by '_render'.

        Raises:
            UserError: If '_render' raises KeyError because the template requires a parameter
                that is either None or not a field of the parameters object.
        """

        # Check params type ('_check_params_type' is not defined in this class)
        self._check_params_type(params)

        # Serialize and convert keys to PascalCase, fields with None values are omitted
        # so that referencing them from the template is reported as a missing parameter
        params_dict = _data_serializer.serialize_data(params)
        params_dict_with_pascal_case_keys = {
            CaseUtil.snake_to_pascal_case(k): v for k, v in params_dict.items() if v is not None
        }

        # Render
        try:
            result = self._render(params_dict_with_pascal_case_keys)
        except KeyError as e:
            # For KeyError, str(e) is the repr of the missing key (includes quotes)
            field_name = str(e)
            params_key_str = _key_serializer.serialize_key(params.get_key())
            # One line per available parameter, each terminated by a newline
            present_keys_str = "".join(f" - {x}\n" for x in params_dict_with_pascal_case_keys)
            # Chain to the original KeyError so the traceback shows where rendering failed
            raise UserError(
                f"Parameter required by prompt is either None or not a field of the parameters object.\n"
                f"Prompt key='{self.prompt_id}'\n"
                f"Parameter name: {field_name}\n"
                f"Parameters object type={type(params).__name__} and key='{params_key_str}'\n"
                f"Available non-empty fields of the parameters object:\n{present_keys_str}\n"
            ) from e
        return result

    @abstractmethod
    def _render(self, dict_with_pascal_case_keys: TDataDict) -> str:
        """Protected method performing the actual rendering, must throw KeyError if a parameter is missing."""
Classes
class TemplatePrompt (*, prompt_id: str = None, params_type: str = None, template: str = None)
-
Uses a template to render the prompt, param names are PascalCase in curly braces.
Expand source code
@dataclass(slots=True, kw_only=True)
class TemplatePrompt(Prompt, ABC):
    """Uses a template to render the prompt, param names are PascalCase in curly braces.

    This class is abstract: 'render' prepares the parameters and reports errors, while the
    actual substitution is delegated to '_render', which derived classes must implement.
    """

    template: str = missing()
    """Template text to render, param names are PascalCase in curly braces."""

    def get_key(self) -> PromptKey:
        """Return a new PromptKey populated with 'prompt_id' of this prompt."""
        return PromptKey(prompt_id=self.prompt_id)

    def render(self, params: RecordMixin) -> str:
        """Use data from the parameters object of 'params_type' to render the template.

        Args:
            params: Record whose fields with non-None values become template parameters,
                field names are converted from snake_case to PascalCase.

        Returns:
            The string returned by '_render'.

        Raises:
            UserError: If '_render' raises KeyError because the template requires a parameter
                that is either None or not a field of the parameters object.
        """

        # Check params type ('_check_params_type' is not defined in this class)
        self._check_params_type(params)

        # Serialize and convert keys to PascalCase, fields with None values are omitted
        # so that referencing them from the template is reported as a missing parameter
        params_dict = _data_serializer.serialize_data(params)
        params_dict_with_pascal_case_keys = {
            CaseUtil.snake_to_pascal_case(k): v for k, v in params_dict.items() if v is not None
        }

        # Render
        try:
            result = self._render(params_dict_with_pascal_case_keys)
        except KeyError as e:
            # For KeyError, str(e) is the repr of the missing key (includes quotes)
            field_name = str(e)
            params_key_str = _key_serializer.serialize_key(params.get_key())
            # One line per available parameter, each terminated by a newline
            present_keys_str = "".join(f" - {x}\n" for x in params_dict_with_pascal_case_keys)
            # Chain to the original KeyError so the traceback shows where rendering failed
            raise UserError(
                f"Parameter required by prompt is either None or not a field of the parameters object.\n"
                f"Prompt key='{self.prompt_id}'\n"
                f"Parameter name: {field_name}\n"
                f"Parameters object type={type(params).__name__} and key='{params_key_str}'\n"
                f"Available non-empty fields of the parameters object:\n{present_keys_str}\n"
            ) from e
        return result

    @abstractmethod
    def _render(self, dict_with_pascal_case_keys: TDataDict) -> str:
        """Protected method performing the actual rendering, must throw KeyError if a parameter is missing."""
Ancestors
- Prompt
- PromptKey
- KeyMixin
- RecordMixin
- abc.ABC
- typing.Generic
Subclasses
Static methods
def get_key_type() -> Type
-
Inherited from:
Prompt
.get_key_type
Return key type even when called from a record.
Fields
var params_type -> str
-
Inherited from:
Prompt
.params_type
Record whose pascalized fields are used as template parameters in ClassName format.
var prompt_id -> str
-
Inherited from:
Prompt
.prompt_id
Unique prompt identifier.
var template -> str
-
Template text used to render the prompt; parameter names are PascalCase in curly braces.
Methods
def get_key(self) -> PromptKey
-
Inherited from:
Prompt
.get_key
Return a new key object whose fields are populated from self; do not return self.
def init_all(self) -> None
-
Inherited from:
Prompt
.init_all
Invoke ‘init’ for each class in the order from base to derived, then validate against schema.
def render(self, params: RecordMixin) -> str
-
Use data from the parameters object of ‘params_type’ to render the template.