Module: annotating_retriever

Expand source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from dataclasses import dataclass
from typing import List
from cl.runtime import Context
from cl.runtime.log.exceptions.user_error import UserError
from cl.runtime.primitive.string_util import StringUtil
from cl.runtime.records.dataclasses_extensions import missing
from cl.convince.entries.entry import Entry
from cl.convince.llms.gpt.gpt_llm import GptLlm
from cl.convince.llms.llm import Llm
from cl.convince.llms.llm_key import LlmKey
from cl.convince.prompts.formatted_prompt import FormattedPrompt
from cl.convince.prompts.prompt import Prompt
from cl.convince.prompts.prompt_key import PromptKey
from cl.convince.retrievers.retrieval import Retrieval
from cl.convince.retrievers.retriever import Retriever
from cl.convince.retrievers.retriever_util import RetrieverUtil

_TRIPLE_BACKTICKS_RE = re.compile(r"```(.*?)```", re.DOTALL)
"""Regex for text between triple backticks, including text that spans multiple lines."""

_BRACES_RE = re.compile(r"\{(.*?)\}", re.DOTALL)
"""Regex for text between curly braces, including text that spans multiple lines (non-greedy, stops at the first closing brace)."""

_TEMPLATE = """You will be provided with an input text and a description of a parameter.
Your goal is to surround each piece of information about this parameter you find in the input text by curly braces.
Use multiple non-nested pairs of opening and closing curly braces if you find more than one piece of information.

You must reply with JSON formatted strictly according to the JSON specification in which all values are strings.
The JSON must have the following keys:

{{
    "success": "<Y if at least one piece of information was found and N otherwise. This parameter is required.>",
    "annotated_text": "<The input text where each piece of information about this parameter is surrounded by curly braces. There should be no changes other than adding curly braces, even to whitespace. Leave this field empty in case of failure.>",
    "justification": "<Justification for your annotations in case of success or the reason why you were not able to find the parameter in case of failure.>"
}}
Input text: ```{InputText}```
Parameter description: ```{ParamDescription}```
"""
"""
Default prompt template for AnnotatingRetriever.

The braces of the JSON skeleton are doubled so that only {InputText} and {ParamDescription} are substituted
under str.format-style rendering (presumably how FormattedPrompt renders templates, confirm). Once rendered,
the skeleton itself is valid JSON with string values, matching the instructions given to the model.
"""


@dataclass(slots=True, kw_only=True)
class AnnotatingRetriever(Retriever):
    """Instructs the model to surround the requested parameter by curly braces and uses the annotations to retrieve."""

    llm: LlmKey = missing()
    """LLM used to perform the retrieval."""

    prompt: PromptKey = missing()
    """Prompt used to perform the retrieval."""

    def init(self) -> None:
        """Same as __init__ but can be used when field values are set both during and after construction."""
        if self.llm is None:
            self.llm = GptLlm(llm_id="gpt-4o")  # TODO: Review the handling of defaults
        if self.prompt is None:
            self.prompt = FormattedPrompt(
                prompt_id="AnnotatingRetriever",
                params_type=Retrieval.__name__,
                template=_TEMPLATE,
            )  # TODO: Review the handling of defaults

    def retrieve(
        self,
        entry_id: str,  # TODO: Generate instead
        input_text: str,
        param_description: str,
        param_samples: List[str] | None = None,
    ) -> str:
        """
        Retrieve the described parameter from the input text and return its value.

        The LLM is asked to return the input text with each piece of information about the parameter
        surrounded by curly braces. A trial is accepted only if all of the following hold:
            - the model reports success
            - the annotated text equals the input text once braces and backticks are removed
            - at least one pair of braces is present
            - braces are not nested
        Up to two trials are made, and any failure on the first trial moves on to the second.

        Args:
            entry_id: Identifier of the source entry, used as the prefix of the retrieval_id
            input_text: Text to retrieve from; leading and trailing whitespace is stripped before use
            param_description: Description of the parameter to retrieve
            param_samples: Optional sample values, stored on the Retrieval record

        Returns:
            Text inside the curly braces, joined by a single space when more than one block is annotated.
            The populated Retrieval record is also saved to the current context.

        Raises:
            UserError: If no trial produces a valid annotation (the last trial's error is chained as the cause).
        """
        # Get LLM and prompt
        llm = Context.current().load_one(Llm, self.llm)
        prompt = Context.current().load_one(Prompt, self.prompt)

        # Strip starting and ending whitespace
        input_text = input_text.strip()  # TODO: Perform more advanced normalization

        # Create a retrieval record
        retrieval = Retrieval(
            retrieval_id=f"{entry_id}: {param_description}",
            input_text=input_text,
            param_description=param_description,
            param_samples=param_samples,
        )

        # Create braces extraction prompt
        rendered_prompt = prompt.render(params=retrieval)

        trial_count = 2
        for trial_index in range(trial_count):
            is_last_trial = trial_index == trial_count - 1
            try:
                # Get text annotated with braces. Every validation failure raises inside _parse_completion,
                # so all failures are handled the same way by the except clause below.
                completion = llm.completion(rendered_prompt, trial_id=trial_index)
                param_value, justification = self._parse_completion(completion, input_text, llm)

                # Populate the output fields and save the retrieval object for validation
                retrieval.success = True
                retrieval.param_value = param_value
                retrieval.justification = justification
                Context.current().save_one(retrieval)

                # Return only the parameter value
                return param_value

            except Exception as e:
                if is_last_trial:
                    # Rethrow only when the last trial is reached, keeping the original error as the cause
                    raise UserError(
                        f"Unable to extract parameter from the input text after {trial_count} trials.\n"
                        f"Input text: {input_text}\n"
                        f"Parameter description: {param_description}\n"
                        f"Last trial error information: {str(e)}\n"
                    ) from e
                # Otherwise proceed to the next trial
                # TODO: Log failure with info message level

        # The last trial either returns or raises, this is a backup in case the loop changes in the future
        raise UserError(
            f"Unable to extract parameter from the input text.\n"
            f"Input text: {input_text}\n"
            f"Parameter description: {param_description}\n"
        )

    @classmethod
    def _parse_completion(cls, completion: str, input_text: str, llm: Llm) -> tuple[str, str | None]:
        """
        Validate one LLM completion and return (param_value, justification).

        Raises an exception for every kind of failure so that the caller can move on to the next trial.
        """
        # Extract the results
        json_result = RetrieverUtil.extract_json(completion)  # TODO(Major): Do not depend on stubs
        if json_result is None:
            raise UserError(f"Could not extract JSON from the LLM response. LLM response:\n{completion}\n")
        success_text = json_result.get("success", None)
        annotated_text = json_result.get("annotated_text", None)
        justification = json_result.get("justification", None)

        # Treat a reported failure as an error so the model's reasoning reaches the caller
        success = Entry.parse_required_bool(success_text, field_name="success_text")
        if not success:
            raise UserError(
                f"The LLM reported that the parameter was not found in the input text.\n"
                f"Justification: {justification}\n"
            )

        if not StringUtil.is_not_empty(annotated_text):
            raise RuntimeError(
                f"Extraction success reported by {llm.llm_id}, however "
                f"the annotated text is empty. Input text:\n{input_text}\n"
            )

        # Compare after removing the curly braces and backticks.
        # NOTE: input text that itself contains curly braces or backticks can never match,
        # because _deannotate removes those characters only from the annotated text.
        if cls._deannotate(annotated_text) != input_text:
            # TODO: Use unified diff
            raise UserError(
                f"Annotated text has changes other than curly braces.\n"
                f"Input text: ```{input_text}```\n"
                f"Annotated text: ```{annotated_text}```\n"
            )

        # Extract data inside braces. The match stops at the first closing brace,
        # so an opening brace inside a match indicates nested braces.
        matches = _BRACES_RE.findall(annotated_text)
        if not matches:
            raise UserError(
                f"Extraction success reported, however no curly braces are present in annotated text.\n"
                f"Annotated text: ```{annotated_text}```\n"
            )
        if any("{" in match or "}" in match for match in matches):
            raise UserError(
                f"Nested curly braces are present in annotated text.\n"
                f"Annotated text: ```{annotated_text}```\n"
            )

        # Combine the annotated blocks
        # TODO: Determine if numbered combination works better
        param_value = " ".join(matches)
        return param_value, justification

    @classmethod
    def _extract_annotated(cls, text: str) -> str:
        """Return the stripped text between the only pair of triple backticks in text, raising if there is not exactly one."""
        # Find all occurrences of triple backticks and the text inside them
        matches = _TRIPLE_BACKTICKS_RE.findall(text)
        if len(matches) == 0:
            raise RuntimeError(f"No string found between triple backticks in: {text}")
        if len(matches) > 1:
            raise RuntimeError(f"More than one string found between triple backticks in: {text}")
        return matches[0].strip()

    @classmethod
    def _extract_in_braces(
        cls, annotated_text: str, *, continue_on_error: bool | None = None
    ) -> str | None:  # TODO: Move to Util class
        """
        Extract the blocks inside curly braces.

        Notes:
            - Return as semicolon-delimited string if more than one block is found
            - If continue_on_error is True, return None without raising an error
        """
        matches = _BRACES_RE.findall(annotated_text)
        if len(matches) == 0:
            if continue_on_error:
                return None
            raise UserError(
                f"No curly braces are present in annotated text.\nAnnotated text: ```{annotated_text}```\n"
            )
        if any("{" in match or "}" in match for match in matches):
            if continue_on_error:
                return None
            raise UserError(
                f"Nested curly braces are present in annotated text.\nAnnotated text: ```{annotated_text}```\n"
            )

        # Combine using semicolon delimiter and return
        return ";".join(matches)

    @classmethod
    def _deannotate(cls, text: str) -> str:
        """Remove all backticks and curly braces from text and strip surrounding whitespace."""
        result = text.replace("`", "").strip().replace("{", "").replace("}", "").strip()
        return result

Classes

class AnnotatingRetriever (*, retriever_id: str = None, llm: LlmKey = None, prompt: PromptKey = None)

Instructs the model to surround the requested parameter by curly braces and uses the annotations to retrieve.

Expand source code
@dataclass(slots=True, kw_only=True)
class AnnotatingRetriever(Retriever):
    """Instructs the model to surround the requested parameter by curly braces and uses the annotations to retrieve."""

    llm: LlmKey = missing()
    """LLM used to perform the retrieval."""

    prompt: PromptKey = missing()
    """Prompt used to perform the retrieval."""

    def init(self) -> None:
        """Same as __init__ but can be used when field values are set both during and after construction."""
        if self.llm is None:
            self.llm = GptLlm(llm_id="gpt-4o")  # TODO: Review the handling of defaults
        if self.prompt is None:
            self.prompt = FormattedPrompt(
                prompt_id="AnnotatingRetriever",
                params_type=Retrieval.__name__,
                template=_TEMPLATE,
            )  # TODO: Review the handling of defaults

    def retrieve(
        self,
        entry_id: str,  # TODO: Generate instead
        input_text: str,
        param_description: str,
        param_samples: List[str] | None = None,
    ) -> str:
        """
        Retrieve the described parameter from the input text and return its value.

        The LLM is asked to return the input text with each piece of information about the parameter
        surrounded by curly braces. A trial is accepted only if all of the following hold:
            - the model reports success
            - the annotated text equals the input text once braces and backticks are removed
            - at least one pair of braces is present
            - braces are not nested
        Up to two trials are made, and any failure on the first trial moves on to the second.

        Args:
            entry_id: Identifier of the source entry, used as the prefix of the retrieval_id
            input_text: Text to retrieve from; leading and trailing whitespace is stripped before use
            param_description: Description of the parameter to retrieve
            param_samples: Optional sample values, stored on the Retrieval record

        Returns:
            Text inside the curly braces, joined by a single space when more than one block is annotated.
            The populated Retrieval record is also saved to the current context.

        Raises:
            UserError: If no trial produces a valid annotation (the last trial's error is chained as the cause).
        """
        # Get LLM and prompt
        llm = Context.current().load_one(Llm, self.llm)
        prompt = Context.current().load_one(Prompt, self.prompt)

        # Strip starting and ending whitespace
        input_text = input_text.strip()  # TODO: Perform more advanced normalization

        # Create a retrieval record
        retrieval = Retrieval(
            retrieval_id=f"{entry_id}: {param_description}",
            input_text=input_text,
            param_description=param_description,
            param_samples=param_samples,
        )

        # Create braces extraction prompt
        rendered_prompt = prompt.render(params=retrieval)

        trial_count = 2
        for trial_index in range(trial_count):
            is_last_trial = trial_index == trial_count - 1
            try:
                # Get text annotated with braces. Every validation failure raises inside _parse_completion,
                # so all failures are handled the same way by the except clause below.
                completion = llm.completion(rendered_prompt, trial_id=trial_index)
                param_value, justification = self._parse_completion(completion, input_text, llm)

                # Populate the output fields and save the retrieval object for validation
                retrieval.success = True
                retrieval.param_value = param_value
                retrieval.justification = justification
                Context.current().save_one(retrieval)

                # Return only the parameter value
                return param_value

            except Exception as e:
                if is_last_trial:
                    # Rethrow only when the last trial is reached, keeping the original error as the cause
                    raise UserError(
                        f"Unable to extract parameter from the input text after {trial_count} trials.\n"
                        f"Input text: {input_text}\n"
                        f"Parameter description: {param_description}\n"
                        f"Last trial error information: {str(e)}\n"
                    ) from e
                # Otherwise proceed to the next trial
                # TODO: Log failure with info message level

        # The last trial either returns or raises, this is a backup in case the loop changes in the future
        raise UserError(
            f"Unable to extract parameter from the input text.\n"
            f"Input text: {input_text}\n"
            f"Parameter description: {param_description}\n"
        )

    @classmethod
    def _parse_completion(cls, completion: str, input_text: str, llm: Llm) -> tuple[str, str | None]:
        """
        Validate one LLM completion and return (param_value, justification).

        Raises an exception for every kind of failure so that the caller can move on to the next trial.
        """
        # Extract the results
        json_result = RetrieverUtil.extract_json(completion)  # TODO(Major): Do not depend on stubs
        if json_result is None:
            raise UserError(f"Could not extract JSON from the LLM response. LLM response:\n{completion}\n")
        success_text = json_result.get("success", None)
        annotated_text = json_result.get("annotated_text", None)
        justification = json_result.get("justification", None)

        # Treat a reported failure as an error so the model's reasoning reaches the caller
        success = Entry.parse_required_bool(success_text, field_name="success_text")
        if not success:
            raise UserError(
                f"The LLM reported that the parameter was not found in the input text.\n"
                f"Justification: {justification}\n"
            )

        if not StringUtil.is_not_empty(annotated_text):
            raise RuntimeError(
                f"Extraction success reported by {llm.llm_id}, however "
                f"the annotated text is empty. Input text:\n{input_text}\n"
            )

        # Compare after removing the curly braces and backticks.
        # NOTE: input text that itself contains curly braces or backticks can never match,
        # because _deannotate removes those characters only from the annotated text.
        if cls._deannotate(annotated_text) != input_text:
            # TODO: Use unified diff
            raise UserError(
                f"Annotated text has changes other than curly braces.\n"
                f"Input text: ```{input_text}```\n"
                f"Annotated text: ```{annotated_text}```\n"
            )

        # Extract data inside braces. The match stops at the first closing brace,
        # so an opening brace inside a match indicates nested braces.
        matches = _BRACES_RE.findall(annotated_text)
        if not matches:
            raise UserError(
                f"Extraction success reported, however no curly braces are present in annotated text.\n"
                f"Annotated text: ```{annotated_text}```\n"
            )
        if any("{" in match or "}" in match for match in matches):
            raise UserError(
                f"Nested curly braces are present in annotated text.\n"
                f"Annotated text: ```{annotated_text}```\n"
            )

        # Combine the annotated blocks
        # TODO: Determine if numbered combination works better
        param_value = " ".join(matches)
        return param_value, justification

    @classmethod
    def _extract_annotated(cls, text: str) -> str:
        """Return the stripped text between the only pair of triple backticks in text, raising if there is not exactly one."""
        # Find all occurrences of triple backticks and the text inside them
        matches = _TRIPLE_BACKTICKS_RE.findall(text)
        if len(matches) == 0:
            raise RuntimeError(f"No string found between triple backticks in: {text}")
        if len(matches) > 1:
            raise RuntimeError(f"More than one string found between triple backticks in: {text}")
        return matches[0].strip()

    @classmethod
    def _extract_in_braces(
        cls, annotated_text: str, *, continue_on_error: bool | None = None
    ) -> str | None:  # TODO: Move to Util class
        """
        Extract the blocks inside curly braces.

        Notes:
            - Return as semicolon-delimited string if more than one block is found
            - If continue_on_error is True, return None without raising an error
        """
        matches = _BRACES_RE.findall(annotated_text)
        if len(matches) == 0:
            if continue_on_error:
                return None
            raise UserError(
                f"No curly braces are present in annotated text.\nAnnotated text: ```{annotated_text}```\n"
            )
        if any("{" in match or "}" in match for match in matches):
            if continue_on_error:
                return None
            raise UserError(
                f"Nested curly braces are present in annotated text.\nAnnotated text: ```{annotated_text}```\n"
            )

        # Combine using semicolon delimiter and return
        return ";".join(matches)

    @classmethod
    def _deannotate(cls, text: str) -> str:
        """Remove all backticks and curly braces from text and strip surrounding whitespace."""
        result = text.replace("`", "").strip().replace("{", "").replace("}", "").strip()
        return result

Ancestors

Static methods

def get_key_type() -> Type

Inherited from: Retriever.get_key_type

Return key type even when called from a record.

Fields

var llm -> LlmKey

LLM used to perform the retrieval.

var prompt -> PromptKey

Prompt used to perform the retrieval.

var retriever_id -> str

Inherited from: Retriever.retriever_id

Unique retriever identifier.

Methods

def get_key(self) -> RetrieverKey

Inherited from: Retriever.get_key

Return a new key object whose fields are populated from self; does not return self.

def init(self) -> None

Same as `__init__` but can be used when field values are set both during and after construction.

def init_all(self) -> None

Inherited from: Retriever.init_all

Invoke ‘init’ for each class in the order from base to derived, then validate against schema.

def retrieve(self, entry_id: str, input_text: str, param_description: str, param_samples: Optional[List[str]] = None) -> str

Inherited from: Retriever.retrieve

Retrieve the specified parameter from the entry and return it as a smaller entry …