Module: completion_cache
Source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import csv
import os
from dataclasses import dataclass
from typing import Any
from typing import Dict
from typing import Iterable
from cl.runtime.context.env_util import EnvUtil
from cl.runtime.records.dataclasses_extensions import field
from cl.runtime.settings.context_settings import ContextSettings
from cl.runtime.settings.project_settings import ProjectSettings
_supported_extensions = ["csv"]
"""The list of supported output file extensions (formats)."""
_csv_headers = ["RequestID", "Query", "Completion"]
"""CSV column headers."""
def _error_extension_not_supported(ext: str) -> Any:
raise RuntimeError(
f"Extension {ext} is not supported by CompletionCache. "
f"Supported extensions: {', '.join(_supported_extensions)}"
)
@dataclass(slots=True, kw_only=True)
class CompletionCache:
"""
Cache LLM completions to reduce AI cost (disable when testing the LLM itself)
Notes:
- After each model call, input and output are recorded in 'channel.completions.csv'
- The channel may be based on llm_id or include some or all of the LLM settings or their hash
- If exactly the same input is subsequently found in the completions file, it is used without calling the LLM
- To record a new completions file, delete the existing one
"""
channel: str | None = None
"""Dot-delimited string or an iterable of dot-delimited tokens to uniquely identify the cache."""
ext: str | None = None
"""Output file extension (format) without the dot prefix, defaults to 'csv'."""
output_path: str | None = None
"""Path for the cache file where completions are stored."""
__completion_dict: Dict[str, str] = field(default_factory=lambda: {}) # TODO: Set using ContextVars
"""Dictionary of completions indexed by query."""
def __post_init__(self):
"""
Load the completions file from disk once on construction. New completions added to this instance
are written to disk but not reused.
"""
# Find base_path=dir_path/test_module by examining the call stack for a test function matching test_*
# Directory 'project_root/completions' is used when not running under a test
default_dir = os.path.join(ContextSettings.instance().get_project_root(), "completions")
base_dir = EnvUtil.get_env_dir(default_dir=default_dir)
# If not found, use base path relative to project root
if base_dir is None:
project_root = ProjectSettings.get_project_root()
base_dir = os.path.join(project_root, "completions")
if self.ext is not None:
# Remove dot prefix if specified
self.ext = self.ext.removeprefix(".")
if self.ext not in _supported_extensions:
_error_extension_not_supported(self.ext)
else:
# Use csv if not specified
self.ext = "csv"
# Cache file path
if self.channel is None or self.channel == "":
cache_filename = f"completions.{self.ext}"
else:
cache_filename = f"{self.channel}.completions.{self.ext}"
self.output_path = os.path.join(base_dir, cache_filename)
# Load cache file from disk
self.load_cache_file()
def add(self, request_id: str, query: str, completion: str, *, trial_id: str | int | None = None) -> None:
"""Add to file even if already exits, the latest will take precedence during lookup."""
# Check if the file already exists
is_new = not os.path.exists(self.output_path)
# If the file does not exist, create its directory if missing
if is_new:
output_dir = os.path.dirname(self.output_path)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
if self.ext == "csv":
with open(self.output_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(
file,
delimiter=",",
quotechar='"',
quoting=csv.QUOTE_MINIMAL,
escapechar="\",
lineterminator=os.linesep,
)
if is_new:
# Write the headers if the file is new
writer.writerow(self.to_os_eol(_csv_headers))
# NOT ADDING THE VALUE TO COMPLETION DICT HERE IS NOT A BUG
# Because we are not adding to the dict here but only writing to a file,
# the model will not reuse cached completions within the same session,
# preventing incorrect measurement of stability
# Get cache key with trial_id, EOL normalization, and stripped leading and trailing whitespace
cache_key = self.normalize_key(query, trial_id=trial_id)
# Remove leading and trailing whitespace and normalize EOL in value
cached_value = self.normalize_value(completion)
# Write the new completion without checking if one already exists
writer.writerow(self.to_os_eol([request_id, cache_key, cached_value]))
# Flush immediately to ensure all of the output is on disk in the event of an exception
file.flush()
else:
# Should not be reached because of the extension check in __post_init__
_error_extension_not_supported(self.ext)
def get(self, query: str, *, trial_id: str | int | None = None) -> str | None:
"""Return completion for the specified query if found and None otherwise."""
# Add trial_id, strip leading and trailing whitespace, and normalize EOL
cache_key = self.normalize_key(query, trial_id=trial_id)
# Look up with trial ID
result = self.__completion_dict.get(cache_key, None)
if result is not None:
# Remove leading and trailing whitespace and normalize EOL in value
result = self.normalize_value(result)
return result
def load_cache_file(self) -> None:
"""Load cache file."""
if os.path.exists(self.output_path):
# Populate the dictionary from the file if it exists
with open(self.output_path, mode="r", newline="", encoding="utf-8") as file:
reader = csv.reader(file, delimiter=",", quotechar='"', escapechar="\\", lineterminator=os.linesep)
# Read and validate the headers
headers_in_file = next(reader, None)
if headers_in_file != _csv_headers:
max_len = 20
headers_in_file = [h if len(h) < max_len else f"{h[:max_len]}..." for h in headers_in_file]
headers_in_file_str = ", ".join(headers_in_file)
expected_headers_str = ", ".join(_csv_headers)
raise ValueError(
f"Expected column headers in completions cache are {expected_headers_str}. "
f"Actual headers: {headers_in_file_str}."
)
# Read cached completions, ignoring request_id at position 0
self.__completion_dict.update({row[1]: row[2] for row_ in reader if (row := self.to_python_eol(row_))})
@classmethod
def normalize_key(cls, query: str, trial_id: str | int | None = None) -> str:
"""Add trial_id, strip leading and trailing whitespace, and normalize EOL."""
# Strip leading and trailing whitespace and EOL
result = query.strip()
# Add trial_id to the beginning of cached query key
if trial_id is not None:
result = f"TrialID: {str(trial_id)}n{result}"
# Normalize EOL
result = cls.to_python_eol(result)
return result
@classmethod
def normalize_value(cls, value: str) -> str:
"""Strip leading and trailing whitespace, and normalize EOL."""
# Strip leading and trailing whitespace and EOL
result = value.strip()
# Normalize EOL
result = cls.to_python_eol(result)
return result
@classmethod
def to_python_eol(cls, data: Iterable[str] | str | None):
"""Convert all types of EOL to n for Python strings."""
if data is None:
return None
if not isinstance(data, str) and isinstance(data, collections.abc.Iterable):
# If data is iterable return list of adjusted elements
# Convert EOL only, do not strip leading or trailing whitespace
return [cls.to_python_eol(x) for x in data]
else:
# Normalize all line-ending variants to \n
data = data.replace("\r\r\n", "\n")
data = data.replace("\r\n", "\n")
return data
@classmethod
def to_os_eol(cls, data: Iterable[str] | str | None):
"""Convert all types of EOL to 'os.linesep' for writing the file to disk."""
if data is None:
return None
if not isinstance(data, str) and isinstance(data, collections.abc.Iterable):
# If data is iterable return list of adjusted elements
return [cls.to_os_eol(x) for x in data]
else:
# Raise an exception if data contains os.linesep sequences other than \n, since
# they would be corrupted by the replacement below.
if os.linesep != "\n" and os.linesep in data:
raise RuntimeError("Cannot normalize data that contains os.linesep sequences other than \\n.")
# Replace \n with os.linesep
adjusted_data = data.replace("\n", os.linesep)
return adjusted_data
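For reference, the cache file this module writes is a plain CSV with the three headers defined in _csv_headers; values containing commas, quotes, or newlines are quoted per csv.QUOTE_MINIMAL. A hypothetical two-row file (contents are illustrative only):

RequestID,Query,Completion
req-001,"Define EOL, briefly",End of line
req-002,What is 2+2?,4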
Classes
class CompletionCache (*, channel: str | None = None, ext: str | None = None, output_path: str | None = None)
-
Cache LLM completions to reduce AI cost (disable when testing the LLM itself)
Notes
- After each model call, input and output are recorded in 'channel.completions.csv'
- The channel may be based on llm_id or include some or all of the LLM settings or their hash
- If exactly the same input is subsequently found in the completions file, it is used without calling the LLM
- To record a new completions file, delete the existing one
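A minimal usage sketch of the read-through pattern these notes describe; the import path, channel name, and call_llm are hypothetical placeholders, and construction assumes the runtime settings consulted in __post_init__ are available:

from cl.runtime.completion_cache import CompletionCache  # import path is an assumption

def call_llm(query: str) -> str:
    """Hypothetical stand-in for the real model call."""
    raise NotImplementedError

cache = CompletionCache(channel="my_llm_id")  # hypothetical channel name

def complete(request_id: str, query: str) -> str:
    # Reuse a completion recorded in a previous session, if any
    cached = cache.get(query)
    if cached is not None:
        return cached
    completion = call_llm(query)
    # Written to disk for future sessions; by design not reused within this session
    cache.add(request_id, query, completion)
    return completion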
Static methods
def normalize_key(query: str, trial_id: str | int | None = None) -> str
-
Add trial_id, strip leading and trailing whitespace, and normalize EOL.
def normalize_value(value: str) -> str
-
Strip leading and trailing whitespace, and normalize EOL.
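For illustration, the two normalizers applied to raw inputs (a sketch with CompletionCache imported as above):

key = CompletionCache.normalize_key("  What is EOL?\r\n", trial_id=3)
# key == "TrialID: 3\nWhat is EOL?" -- prefix added, whitespace stripped, EOL normalized

value = CompletionCache.normalize_value("  End of line\r\n")
# value == "End of line"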
def to_os_eol(data: Iterable[str] | str | None)
-
Convert all types of EOL to 'os.linesep' for writing the file to disk.
def to_python_eol(data: Iterable[str] | str | None)
-
Convert all types of EOL to \n for Python strings.
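A round trip between the two EOL converters, assuming input text with Windows line endings:

text = CompletionCache.to_python_eol("line one\r\nline two")
# text == "line one\nline two"
disk = CompletionCache.to_os_eol(text)
# every "\n" becomes os.linesep (a no-op on platforms where os.linesep == "\n")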
Fields
var channel -> str | None
-
Dot-delimited string or an iterable of dot-delimited tokens to uniquely identify the cache.
var ext -> str | None
-
Output file extension (format) without the dot prefix, defaults to 'csv'.
var output_path -> str | None
-
Path for the cache file where completions are stored.
Methods
def add(self, request_id: str, query: str, completion: str, *, trial_id: str | int | None = None) -> None
-
Add to the file even if the query already exists; the latest entry takes precedence during lookup.
def get(self, query: str, *, trial_id: str | int | None = None) -> str | None
-
Return completion for the specified query if found and None otherwise.
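Because add only appends to the file and deliberately does not update the in-memory dictionary, a completion added in the current session is not visible to get until the file is reloaded; a sketch of this behavior, continuing with the cache instance from the earlier example:

cache.add("req-003", "Ping?", "Pong")  # appended to the CSV file only
assert cache.get("Ping?") is None      # in-memory dictionary unchanged this session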
def load_cache_file(self) -> None
-
Load cache file.
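Continuing the sketch above, reloading the file makes completions written during this session visible to get:

cache.load_cache_file()                # re-read the CSV from disk
assert cache.get("Ping?") == "Pong"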