from dataclasses import dataclass
from typing import Any

from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import BaseMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_ollama import ChatOllama

from reviewllama.vector_store import get_context_from_store

from .configs import OllamaConfig


@dataclass(frozen=True)
class ChatClient:
    """Immutable pairing of a runnable chat chain with its conversation memory.

    Frozen so that each turn of the conversation produces a new ChatClient
    (see chat_with_client) rather than mutating shared state in place.
    """

    # Chain mapping {"input", "memory", "context"} -> a single model message.
    chain: RunnableSerializable[dict[str, Any], BaseMessage]
    # Buffer holding the full alternating human/AI message history.
    memory: ConversationBufferMemory

    def get_last_response_or_none(self) -> BaseMessage | None:
        """Return the most recent message in memory, or None if history is empty.

        Because chat_with_client appends the AI message last, the final entry
        is the latest model response once at least one exchange has occurred.
        """
        try:
            return self.memory.chat_memory.messages[-1]
        except IndexError:
            # No messages recorded yet.
            return None


def create_chat_chain(
    config: OllamaConfig,
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
    """Create the chat chain for use by the code.

    The chain expects an input dict with keys "input" (the user question),
    "memory" (a ConversationBufferMemory), and "context" (RAG context text),
    and returns the model's response message.
    """
    llm = ChatOllama(
        model=config.chat_model,
        base_url=config.base_url,
        temperature=config.temperature,
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", config.system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "Context\n{context}\n\nQuestion:\n{input}"),
        ]
    )

    def get_chat_history(inputs: dict) -> list:
        """Extract chat history from memory object."""
        # A missing "memory" key raises KeyError, while an object without
        # chat_memory raises AttributeError; handle both (the original code
        # only caught AttributeError, letting KeyError escape).
        try:
            return inputs["memory"].chat_memory.messages
        except (KeyError, AttributeError):
            return []

    def get_context(inputs: dict) -> str:
        """Extract the RAG context from the input object."""
        # dict access raises KeyError (not AttributeError) when the key is
        # absent; .get covers that case and defaults to no context.
        return inputs.get("context", "")

    # Merge the derived chat_history/context into the input dict, then render
    # the prompt and invoke the model.
    return (
        RunnablePassthrough.assign(
            chat_history=RunnableLambda(get_chat_history),
            context=RunnableLambda(get_context),
        )
        | prompt
        | llm
    )


def create_memory() -> ConversationBufferMemory:
    """Create a conversation memory keyed for the prompt's chat_history slot."""
    return ConversationBufferMemory(memory_key="chat_history", return_messages=True)


def create_chat_client(config: OllamaConfig) -> ChatClient:
    """Build a ChatClient with a fresh chain and empty conversation memory."""
    return ChatClient(
        chain=create_chat_chain(config),
        memory=create_memory(),
    )


def chat_with_client(
    client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None
) -> ChatClient:
    """Send one message through the client's chain and record the exchange.

    When a retriever is supplied, relevant context is fetched from the vector
    store and injected into the prompt; otherwise the context is empty.
    Returns a new ChatClient sharing the chain, with the updated memory.
    """
    # Explicit None check: don't rely on the retriever object's truthiness.
    if retriever is not None:
        context = get_context_from_store(message, retriever)
    else:
        context = ""

    response = client.chain.invoke(
        {"input": message, "memory": client.memory, "context": context}
    )

    # Record the exchange: user message first, then the AI reply, so the last
    # entry in memory is always the latest model response.
    memory = client.memory
    memory.chat_memory.add_user_message(message)
    memory.chat_memory.add_ai_message(response.content)

    return ChatClient(chain=client.chain, memory=memory)