ReviewLlama/src/reviewllama/llm.py

78 lines
2.1 KiB
Python

from dataclasses import dataclass
from typing import Any
from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_ollama import ChatOllama
from reviewllama.vector_store import get_context_from_store
from .configs import OllamaConfig
@dataclass(frozen=True)
class ChatClient:
    """Immutable holder for the runnable chat chain.

    Instances are produced by ``create_chat_client`` and consumed by
    ``chat_with_client``, which calls ``chain.invoke`` with a dict containing
    ``input`` and ``context`` keys.
    """

    # Runnable pipeline (context defaulting -> prompt -> chat model) built by
    # create_chat_chain; typed to match that function's return annotation.
    chain: RunnableSerializable[dict[str, Any], BaseMessage]
def create_chat_chain(
    config: OllamaConfig,
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
    """Create the base chat chain: context defaulting -> prompt -> Ollama model.

    Args:
        config: Supplies ``chat_model``, ``base_url`` and ``temperature`` for
            the ``ChatOllama`` model, and ``system_prompt`` for the system
            message of the prompt.

    Returns:
        A runnable that accepts a dict with an ``input`` key and an optional
        ``context`` key, and yields the chat model's reply message.
    """
    llm = ChatOllama(
        model=config.chat_model,
        base_url=config.base_url,
        temperature=config.temperature,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", config.system_prompt),
            ("human", "Context\n{context}\n\nQuestion:\n{input}"),
        ]
    )

    def get_context(inputs: dict) -> str:
        """Return the RAG context from *inputs*, using "" if absent or None."""
        # A plain ``.get("context", "")`` lets an explicit None through, and the
        # prompt would then render the literal text "None" as the context.
        context = inputs.get("context")
        return "" if context is None else context

    # Re-assigning "context" guarantees the key is present (and never None)
    # before the prompt, whose human message references {context} and {input}.
    return (
        RunnablePassthrough.assign(context=RunnableLambda(get_context))
        | prompt
        | llm
    )
def create_chat_client(config: OllamaConfig) -> ChatClient:
    """Build a ``ChatClient`` around the chain produced by ``create_chat_chain``.

    Args:
        config: Ollama settings forwarded unchanged to ``create_chat_chain``.

    Returns:
        A frozen ``ChatClient`` whose ``chain`` is ready for ``chat_with_client``.
    """
    return ChatClient(chain=create_chat_chain(config))
def chat_with_client(
    client: ChatClient,
    message: str,
    retriever: VectorStoreRetriever | None = None,
    session_id: str = "default",
    verbose: bool = False,
) -> str:
    """Send *message* through the client's chain and return the reply as text.

    Args:
        client: Holds the runnable chain to invoke.
        message: The user question; fills the ``input`` key of the chain input.
        retriever: Optional retriever. When given, ``get_context_from_store``
            is called with *message* and the retriever, and its result fills
            the ``context`` key; otherwise the context is the empty string.
        session_id: Forwarded as ``configurable.session_id`` in the invoke
            config. NOTE(review): the chain from ``create_chat_chain`` has no
            message-history wrapper, so nothing visible here reads it.
        verbose: Currently unused; kept for call-site compatibility.

    Returns:
        The reply content as a string. If the model returns structured
        content (a list of strings / content-block dicts), the string parts
        and the ``text`` of ``{"type": "text"}`` blocks are concatenated.
    """
    context = (
        get_context_from_store(message, retriever) if retriever is not None else ""
    )
    response = client.chain.invoke(
        {"input": message, "context": context},
        config={"configurable": {"session_id": session_id}},
    )

    content = response.content
    if isinstance(content, str):
        return content

    # BaseMessage.content may also be a list of str and content-block dicts;
    # honour the ``-> str`` contract by keeping only the textual parts.
    text_parts: list[str] = []
    for part in content:
        if isinstance(part, str):
            text_parts.append(part)
        elif isinstance(part, dict) and part.get("type") == "text":
            text_parts.append(str(part.get("text", "")))
    return "".join(text_parts)