Module: fireworks_llama_llm
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
import fireworks.client # noqa
from cl.convince.llms.llama.llama_llm import LlamaLlm
from cl.convince.llms.llm import Llm
from cl.convince.settings.fireworks_settings import FireworksSettings
@dataclass(slots=True, kw_only=True)
class FireworksLlamaLlm(LlamaLlm):
    """Implements LLAMA API running in the Fireworks cloud."""

    model_name: str | None = None
    """Model name in Fireworks format, including version if any; defaults to 'llm_id'."""

    max_tokens: int = 4096
    """Maximum number of tokens the model will generate in response to the query."""

    def uncached_completion(self, request_id: str, query: str) -> str:
        """Perform completion without CompletionCache lookup; call 'completion' instead."""
        # Prefix a unique RequestID to the query for audit log purposes and
        # to stop the model provider from caching the results
        query_with_request_id = f"RequestID: {request_id}\n\n{query}"
        model_name = self.model_name if self.model_name is not None else self.llm_id
        # Wrap the query in the Llama 3 chat template: a user turn followed by
        # an open assistant header so the model generates the reply
        prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>
{query_with_request_id}<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>"""
        fireworks.client.api_key = FireworksSettings.instance().api_key
        response = fireworks.client.Completion.create(
            model=f"accounts/fireworks/models/{model_name}", prompt=prompt, max_tokens=self.max_tokens
        )
        result = response.choices[0].text
        return result
Classes
class FireworksLlamaLlm (*, llm_id: str = None, model_name: str | None = None, max_tokens: int = 4096)
-
Implements LLAMA API running in the Fireworks cloud.
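A minimal usage sketch, not taken from the source: the import path is inferred from the module imports shown above, and the model name "llama-v3p1-8b-instruct" and a configured Fireworks API key are assumptions for illustration.

    from cl.convince.llms.llama.fireworks_llama_llm import FireworksLlamaLlm  # path assumed

    # llm_id doubles as the Fireworks model name when model_name is not set
    llm = FireworksLlamaLlm(llm_id="llama-v3p1-8b-instruct")
    llm.init_all()  # run base-to-derived init, then validate against schema
    answer = llm.completion("What is the capital of France?")  # cached text-in, text-out call
    print(answer)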
Ancestors
- LlamaLlm
- Llm
- LlmKey
- KeyMixin
- RecordMixin
- abc.ABC
- typing.Generic
Static methods
def create_prompt_from_messages(messages: list[dict]) -> str
-
Inherited from:
LlamaLlm
.create_prompt_from_messages
Transforms list of messages in the following format: [ {"role": "system", "content": "System Prompt"}, {"role": "user", "content": "What is 2 …
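A hedged sketch of the expected transformation; the exact template lives in LlamaLlm and is not shown here, so the commented output below is inferred from the Llama 3 tags used in uncached_completion above, and the message contents are hypothetical.

    messages = [
        {"role": "system", "content": "System Prompt"},
        {"role": "user", "content": "Explain tokenization briefly."},  # hypothetical content
    ]
    prompt = FireworksLlamaLlm.create_prompt_from_messages(messages)
    # Presumed shape of the result (assumption, not verified against LlamaLlm):
    # <|begin_of_text|><|start_header_id|>system<|end_header_id|>
    # System Prompt<|eot_id|>
    # <|start_header_id|>user<|end_header_id|>
    # Explain tokenization briefly.<|eot_id|>
    # <|start_header_id|>assistant<|end_header_id|>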
def create_prompt_from_messages_v2(messages: list[dict]) -> str
-
Inherited from:
LlamaLlm
.create_prompt_from_messages_v2
Transforms list of messages in the following format: [ {"role": "system", "content": "System Prompt"}, {"role": "user", "content": "What is 2 …
def get_key_type() -> Type
-
Inherited from:
LlamaLlm
.get_key_type
Return key type even when called from a record.
Fields
var llm_id -> str
-
Inherited from:
LlamaLlm
.llm_id
Unique LLM identifier.
var max_tokens -> int
-
Maximum number of tokens the model will generate in response to the query.
var model_name -> str | None
-
Model name in Fireworks format, including version if any; defaults to 'llm_id'.
Methods
def completion(self, query: str, *, trial_id: str | int | None = None) -> str
-
Inherited from:
LlamaLlm
.completion
Text-in, text-out single query completion without model-specific tags (uses response caching).
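For example, a sketch building on the llm instance above; the trial_id semantics are inferred from the caching note, where distinct trial_id values presumably key separate cache entries for the same query.

    # Same query, different trials: each trial_id presumably gets its own cache slot
    first = llm.completion("Name one prime number.")
    retry = llm.completion("Name one prime number.", trial_id=1)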
def get_key(self) -> LlmKey
-
Inherited from:
LlamaLlm
.get_key
Return a new key object whose fields are populated from self; do not return self.
def init_all(self) -> None
-
Inherited from:
LlamaLlm
.init_all
Invoke 'init' for each class in the order from base to derived, then validate against schema.
def uncached_completion(self, request_id: str, query: str) -> str
-
Inherited from:
LlamaLlm
.uncached_completion
Perform completion without CompletionCache lookup; call 'completion' instead.
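For example, a sketch using the llm instance above; "req-001" is an arbitrary caller-supplied RequestID, prefixed to the query for audit logging and to defeat provider-side caching as shown in the source.

    raw = llm.uncached_completion("req-001", "Name one prime number.")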