"""reviewllama/llm.py — Ollama-backed chat chain, conversation memory, and client helpers."""

from dataclasses import dataclass
from pathlib import Path
from typing import Any
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import BaseMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_ollama import ChatOllama
from reviewllama.vector_store import get_context_from_store
from .configs import OllamaConfig
@dataclass(frozen=True)
class ChatClient:
    """Immutable pairing of a runnable chat chain with its conversation memory."""

    # Runnable mapping {"input", "memory", "context"} dicts to a model message.
    chain: RunnableSerializable[dict[str, Any], BaseMessage]
    # Buffer of the alternating human/AI message history.
    memory: ConversationBufferMemory

    def get_last_response_or_none(self) -> BaseMessage | None:
        """Return the most recent message in memory, or None when empty."""
        history = self.memory.chat_memory.messages
        return history[-1] if history else None
def create_chat_chain(
    config: OllamaConfig,
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
    """Create the RAG-augmented chat chain: history/context injection -> prompt -> LLM.

    Args:
        config: Ollama settings (chat model name, base URL, temperature,
            system prompt).

    Returns:
        A runnable that maps {"input", "memory", "context"} dicts to a
        model-generated BaseMessage.
    """
    llm = ChatOllama(
        model=config.chat_model,
        base_url=config.base_url,
        temperature=config.temperature,
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", config.system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "Context\n{context}\n\nQuestion:\n{input}"),
        ]
    )

    def get_chat_history(inputs: dict) -> list:
        """Extract chat history from the memory object; [] when absent or malformed."""
        # BUG FIX: a missing "memory" key raises KeyError, not AttributeError,
        # so the original fallback never fired for that case.
        try:
            return inputs["memory"].chat_memory.messages
        except (KeyError, AttributeError):
            return []

    def get_context(inputs: dict) -> str:
        """Extract the RAG context string; "" when the key is absent."""
        # BUG FIX: a dict lookup raises KeyError — the original caught
        # AttributeError, which inputs["context"] can never raise on a dict.
        return inputs.get("context", "")

    return (
        RunnablePassthrough.assign(
            chat_history=RunnableLambda(get_chat_history),
            context=RunnableLambda(get_context),
        )
        | prompt
        | llm
    )
def create_memory() -> ConversationBufferMemory:
    """Build a fresh conversation buffer keyed to the prompt's chat_history slot."""
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    return memory
def create_chat_client(config: OllamaConfig) -> ChatClient:
    """Assemble a ChatClient from a newly built chain and an empty memory."""
    chain = create_chat_chain(config)
    memory = create_memory()
    return ChatClient(chain=chain, memory=memory)
def chat_with_client(
    client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None
) -> ChatClient:
    """Send one user message through the client's chain and record the exchange.

    Args:
        client: The chain/memory pair to converse with.
        message: The user's message text.
        retriever: Optional vector-store retriever used to fetch RAG context;
            when None, an empty context string is sent.

    Returns:
        A ChatClient wrapping the same chain and the updated memory.
    """
    # BUG FIX: compare against None rather than truthiness — a valid retriever
    # object that happens to evaluate falsy would silently skip RAG retrieval.
    if retriever is not None:
        context = get_context_from_store(message, retriever)
    else:
        context = ""
    response = client.chain.invoke(
        {"input": message, "memory": client.memory, "context": context}
    )
    # NOTE(review): memory is mutated in place, so the returned ChatClient
    # shares state with the argument despite the frozen-dataclass intent —
    # confirm callers do not rely on the old client's memory being unchanged.
    memory = client.memory
    memory.chat_memory.add_user_message(message)
    memory.chat_memory.add_ai_message(response.content)
    return ChatClient(chain=client.chain, memory=memory)