101 lines
2.8 KiB
Python
101 lines
2.8 KiB
Python
|
from dataclasses import dataclass
|
||
|
from typing import Any
|
||
|
|
||
|
from langchain.memory import ConversationBufferMemory
|
||
|
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||
|
from langchain.schema import BaseMessage
|
||
|
from langchain_core.runnables import RunnableLambda
|
||
|
from langchain_core.runnables.base import RunnableSerializable
|
||
|
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||
|
from langchain_core.vectorstores import VectorStoreRetriever
|
||
|
from langchain_ollama import ChatOllama
|
||
|
|
||
|
from reviewllama.vector_store import get_context_from_store
|
||
|
|
||
|
from .configs import OllamaConfig
|
||
|
|
||
|
|
||
|
@dataclass(frozen=True)
class ChatClient:
    """Immutable pairing of a runnable chat chain with its conversation memory."""

    # Runnable pipeline mapping {"input", "memory", "context"} dicts to a model message.
    chain: RunnableSerializable[dict[str, Any], BaseMessage]
    # Conversation history backing the chain's "chat_history" placeholder.
    memory: ConversationBufferMemory

    def get_last_response_or_none(self):
        """Return the most recent stored message, or None when history is empty."""
        messages = self.memory.chat_memory.messages
        if not messages:
            return None
        return messages[-1]
|
||
|
|
||
|
|
||
|
def create_chat_chain(
    config: OllamaConfig,
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
    """Create the chat chain for use by the code.

    The returned chain expects an input dict with keys "input" (the user's
    question), "memory" (a ConversationBufferMemory, or an object exposing
    chat_memory.messages), and "context" (RAG context text). It fills the
    prompt's chat_history and context slots, then invokes the Ollama model.
    """
    llm = ChatOllama(
        model=config.chat_model,
        base_url=config.base_url,
        temperature=config.temperature,
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", config.system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "Context\n{context}\n\nQuestion:\n{input}"),
        ]
    )

    def get_chat_history(inputs: dict) -> list:
        """Extract chat history from memory object."""
        try:
            return inputs["memory"].chat_memory.messages
        # KeyError: "memory" absent from the input dict. AttributeError: the
        # memory object lacks chat_memory (e.g. None was supplied). The
        # original code only caught AttributeError, so a missing key escaped.
        except (KeyError, AttributeError):
            return []

    def get_context(inputs: dict) -> str:
        """Extract the RAG context from the input object"""
        try:
            return inputs["context"]
        # A missing dict key raises KeyError, not AttributeError; the original
        # except clause could never trigger and the lookup error propagated.
        except KeyError:
            return ""

    return (
        RunnablePassthrough.assign(
            chat_history=RunnableLambda(get_chat_history),
            context=RunnableLambda(get_context),
        )
        | prompt
        | llm
    )
|
||
|
|
||
|
|
||
|
def create_memory():
    """Build a fresh conversation buffer keyed for the prompt's chat_history slot."""
    buffer = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
    )
    return buffer
|
||
|
|
||
|
|
||
|
def create_chat_client(config: OllamaConfig):
    """Assemble a ChatClient with a freshly built chain and empty memory."""
    chain = create_chat_chain(config)
    memory = create_memory()
    return ChatClient(chain=chain, memory=memory)
|
||
|
|
||
|
|
||
|
def chat_with_client(
    client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None
) -> ChatClient:
    """Send *message* through the client's chain and record the exchange.

    When a retriever is supplied, context is fetched from the vector store and
    injected into the prompt; otherwise the context slot is left empty.
    Returns a new ChatClient carrying the updated memory (the chain is reused).
    """
    # Explicit None test: the parameter is annotated `| None`, and relying on
    # the truthiness of a retriever object is not a reliable presence check.
    if retriever is not None:
        context = get_context_from_store(message, retriever)
    else:
        context = ""

    response = client.chain.invoke(
        {"input": message, "memory": client.memory, "context": context}
    )

    # Append the exchange to the history. NOTE(review): this mutates the memory
    # object shared with the incoming client, so prior ChatClient references
    # observe the update even though a new ChatClient is returned.
    memory = client.memory
    memory.chat_memory.add_user_message(message)
    memory.chat_memory.add_ai_message(response.content)

    return ChatClient(chain=client.chain, memory=memory)
|