from dataclasses import dataclass
from typing import Any

from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_ollama import ChatOllama

from reviewllama.vector_store import get_context_from_store

from .configs import OllamaConfig


@dataclass(frozen=True)
class ChatClient:
    """Immutable wrapper around the runnable chat chain."""

    # The composed prompt | llm pipeline built by create_chat_chain.
    chain: RunnableSerializable


def create_chat_chain(
    config: OllamaConfig,
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
    """Create the base chat chain for use by the code.

    Args:
        config: Ollama connection and prompting settings (model name,
            base URL, temperature, system prompt).

    Returns:
        A runnable that maps {"input": ..., "context": ...} to the
        model's BaseMessage response.
    """
    llm = ChatOllama(
        model=config.chat_model,
        base_url=config.base_url,
        temperature=config.temperature,
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", config.system_prompt),
            ("human", "Context\n{context}\n\nQuestion:\n{input}"),
        ]
    )

    def get_context(inputs: dict) -> str:
        """Extract the RAG context from the input object."""
        # Default to an empty string so the prompt still formats when no
        # retriever context was supplied.
        return inputs.get("context", "")

    # RunnablePassthrough.assign forwards the original input dict while
    # (re)computing the "context" key, guaranteeing the prompt's
    # {context} placeholder is always populated.
    return (
        RunnablePassthrough.assign(
            context=RunnableLambda(get_context),
        )
        | prompt
        | llm
    )


def create_chat_client(config: OllamaConfig) -> ChatClient:
    """Build a ChatClient holding a freshly constructed chat chain.

    Args:
        config: Ollama connection and prompting settings.

    Returns:
        A ChatClient wrapping the base chain.
    """
    base_chain = create_chat_chain(config)

    return ChatClient(
        chain=base_chain,
    )


def chat_with_client(
    client: ChatClient,
    message: str,
    retriever: VectorStoreRetriever | None = None,
    session_id: str = "default",
    verbose: bool = False,
) -> str:
    """Chat with the client and return the response content.

    Args:
        client: The ChatClient whose chain is invoked.
        message: The user's question/message.
        retriever: Optional vector-store retriever; when provided, RAG
            context for the message is fetched and injected into the prompt.
        session_id: Forwarded via the runnable config for session-scoped
            configuration (e.g. message history, if wired up by callers).
        verbose: Accepted for interface compatibility; currently unused.

    Returns:
        The text content of the model's response.
    """
    # Explicit None check: truthiness of a retriever object is not a
    # reliable signal that it should be skipped.
    if retriever is not None:
        context = get_context_from_store(message, retriever)
    else:
        context = ""

    response = client.chain.invoke(
        {"input": message, "context": context},
        config={"configurable": {"session_id": session_id}},
    )
    return response.content