from dataclasses import dataclass
from typing import Any

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_ollama import ChatOllama

from reviewllama.vector_store import get_context_from_store

from .configs import OllamaConfig


@dataclass(frozen=True)
class ChatClient:
    """Immutable handle pairing a history-aware chat chain with its session store.

    The dataclass is frozen, but ``message_history`` is a mutable dict shared
    with the chain's session-history factory, so histories still accumulate
    as the chain is invoked.
    """

    # History-aware runnable produced by RunnableWithMessageHistory.
    chain: RunnableWithMessageHistory
    # Maps session_id -> that session's message history (shared with the chain).
    message_history: dict[str, BaseChatMessageHistory]

    def get_last_response_or_none(
        self, session_id: str = "default"
    ) -> BaseMessage | None:
        """Return the most recent message for *session_id*, or None.

        Returns None when the session does not exist or has no messages yet.
        """
        try:
            messages = self.message_history[session_id].messages
            return messages[-1] if messages else None
        except (KeyError, IndexError):
            # KeyError: unknown session. IndexError is retained defensively
            # even though the truthiness guard above already prevents it.
            return None


def create_chat_chain(
    config: OllamaConfig,
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
    """Create the base chat chain for use by the code.

    Builds: input dict -> (ensure "context" key) -> prompt template -> Ollama LLM.
    The chain expects {"input": ..., "context": ...} plus a "history" list
    supplied by RunnableWithMessageHistory.
    """
    llm = ChatOllama(
        model=config.chat_model,
        base_url=config.base_url,
        temperature=config.temperature,
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", config.system_prompt),
            MessagesPlaceholder("history"),
            ("human", "Context\n{context}\n\nQuestion:\n{input}"),
        ]
    )

    def get_context(inputs: dict) -> str:
        """Extract the RAG context from the input object"""
        return inputs.get("context", "")

    # The assign step defaults "context" to "" when callers omit it, so the
    # prompt's {context} placeholder always has a value.
    return (
        RunnablePassthrough.assign(
            context=RunnableLambda(get_context),
        )
        | prompt
        | llm
    )


def get_session_history(
    session_id: str, message_history: dict[str, BaseChatMessageHistory]
) -> BaseChatMessageHistory:
    """Get or create chat message history for a session."""
    if session_id not in message_history:
        message_history[session_id] = ChatMessageHistory()
    return message_history[session_id]


def create_chat_client(config: OllamaConfig) -> ChatClient:
    """Build a ChatClient whose chain records per-session message history.

    The same ``message_history`` dict is captured by the session-history
    factory and stored on the returned ChatClient, so callers can inspect
    histories that the chain populates.
    """
    base_chain = create_chat_chain(config)
    message_history: dict[str, BaseChatMessageHistory] = {}

    chain_with_history = RunnableWithMessageHistory(
        base_chain,
        lambda session_id: get_session_history(session_id, message_history),
        input_messages_key="input",
        history_messages_key="history",
    )

    return ChatClient(
        chain=chain_with_history,
        message_history=message_history,
    )


def chat_with_client(
    client: ChatClient,
    message: str,
    retriever: VectorStoreRetriever | None = None,
    session_id: str = "default",
    verbose: bool = False,  # NOTE(review): currently unused; kept for interface compatibility
) -> str:
    """Chat with the client and return the response content.

    When *retriever* is given, RAG context for *message* is fetched from the
    vector store; otherwise an empty context string is sent.
    """
    context = get_context_from_store(message, retriever) if retriever else ""

    response = client.chain.invoke(
        {"input": message, "context": context},
        config={"configurable": {"session_id": session_id}},
    )
    return response.content