Module: llm

# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod
from dataclasses import dataclass
from cl.runtime.primitive.timestamp import Timestamp
from cl.runtime.records.record_mixin import RecordMixin
from cl.convince.llms.completion_cache import CompletionCache
from cl.convince.llms.llm_key import LlmKey


@dataclass(slots=True, kw_only=True)
class Llm(LlmKey, RecordMixin[LlmKey], ABC):
    """Provides an API for single query and chat completion."""

    _completion_cache: CompletionCache | None = None
    """Completion cache is used to return cached LLM responses."""

    def get_key(self) -> LlmKey:
        return LlmKey(llm_id=self.llm_id)

    def completion(self, query: str, *, trial_id: str | int | None = None) -> str:
        """Text-in, text-out single query completion without model-specific tags (uses response caching)."""

        # Remove leading and trailing whitespace and normalize EOL in query
        query = CompletionCache.normalize_value(query)

        # Create the completion cache if it does not exist
        if self._completion_cache is None:
            self._completion_cache = CompletionCache(channel=self.llm_id)

        # Try to find the query in the completion cache; make a cloud provider call only if not found
        if (result := self._completion_cache.get(query, trial_id=trial_id)) is None:
            # The request identifier is a UUIDv7 timestamp in time-ordered, dash-delimited
            # format; it is used to prevent LLM cloud provider caching and to identify
            # LLM API calls for audit log and error reporting purposes
            request_id = Timestamp.create()

            # Invoke LLM by calling the cloud provider API
            result = self.uncached_completion(request_id, query)

            # Save the result in the cache before returning; request_id is recorded
            # but not taken into account during lookup
            self._completion_cache.add(request_id, query, result, trial_id=trial_id)

        # Remove leading and trailing whitespace and normalize EOL in result
        result = CompletionCache.normalize_value(result)
        return result

    @abstractmethod
    def uncached_completion(self, request_id: str, query: str) -> str:
        """Perform completion without CompletionCache lookup, call completion instead."""

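For illustration, a minimal concrete subclass might look like the sketch below. EchoLlm, its echo behavior, and the import path are assumptions for this example; a real subclass would call the cloud provider API inside uncached_completion.

from dataclasses import dataclass
from cl.convince.llms.llm import Llm


@dataclass(slots=True, kw_only=True)
class EchoLlm(Llm):
    """Hypothetical stub that echoes the query instead of calling a provider."""

    def uncached_completion(self, request_id: str, query: str) -> str:
        # A real implementation would pass request_id to the provider call
        # to defeat provider-side caching and to tag the call for audit logs
        return query

Because completion already handles normalization and caching, a subclass only needs to implement the raw provider call.
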
Classes

class Llm (*, llm_id: str = None)

Provides an API for single query and chat completion.

Ancestors

LlmKey
RecordMixin
ABC

Subclasses

Static methods

def get_key_type() -> Type

Inherited from: LlmKey.get_key_type

Return key type even when called from a record.

Fields

var llm_id -> str

Inherited from: LlmKey.llm_id

Unique LLM identifier.

Methods

def completion(self, query: str, *, trial_id: str | int | None = None) -> str

Text-in, text-out single query completion without model-specific tags (uses response caching).
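
A hypothetical usage sketch (EchoLlm is the illustrative subclass shown after the source code above); a repeated call with the same query and trial_id is served from the completion cache instead of the provider:

llm = EchoLlm(llm_id="echo-v1")
first = llm.completion("What is 2 + 2?", trial_id=1)   # cache miss: calls uncached_completion
second = llm.completion("What is 2 + 2?", trial_id=1)  # cache hit: no provider call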

def get_key(self) -> LlmKey

Inherited from: RecordMixin.get_key

Return a new key object whose fields are populated from self; do not return self.

def init_all(self) -> None

Inherited from: RecordMixin.init_all

Invoke 'init' for each class in order from base to derived, then validate against the schema.

def uncached_completion(self, request_id: str, query: str) -> str

Perform completion without CompletionCache lookup; callers should use completion instead.