From 1c75cfc71608d656d301aa0da853b419c4da9e88 Mon Sep 17 00:00:00 2001 From: Alex Selimov Date: Fri, 18 Jul 2025 22:21:31 -0400 Subject: [PATCH] Improve LangChain memory implementation - Swap to RunnableWithMessageHistory - Add verbosity flag --- src/reviewllama/cli.py | 8 ++++ src/reviewllama/configs.py | 20 ++++++--- src/reviewllama/llm.py | 78 +++++++++++++++++++--------------- src/reviewllama/reviewllama.py | 17 +++++--- tests/test_llm.py | 9 ++-- 5 files changed, 81 insertions(+), 51 deletions(-) diff --git a/src/reviewllama/cli.py b/src/reviewllama/cli.py index 43fa565..289237d 100644 --- a/src/reviewllama/cli.py +++ b/src/reviewllama/cli.py @@ -90,6 +90,14 @@ Examples: ), help="Base branch to compare against (default: %(default)s)", ) + + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="Enable verbose output, including messages sent to the LLM.", + ) return parser diff --git a/src/reviewllama/configs.py b/src/reviewllama/configs.py index d577915..45e36cb 100644 --- a/src/reviewllama/configs.py +++ b/src/reviewllama/configs.py @@ -23,6 +23,7 @@ class ReviewConfig: paths: List[Path] ollama: OllamaConfig base_branch: str + verbose: bool def create_ollama_config( @@ -43,19 +44,24 @@ def create_ollama_config( def create_review_config( - paths: List[Path], ollama_config: OllamaConfig, base_branch: str + paths: List[Path], ollama_config: OllamaConfig, base_branch: str, verbose ) -> ReviewConfig: """Create complete ReviewConfig from validated components.""" - return ReviewConfig(paths=paths, ollama=ollama_config, base_branch=base_branch) + return ReviewConfig( + paths=paths, ollama=ollama_config, base_branch=base_branch, verbose=verbose + ) -def namespace_to_config( - namespace: argparse.Namespace -): +def namespace_to_config(namespace: argparse.Namespace): """Transform argparse namespace into ReviewConfig.""" paths = [Path(path_str) for path_str in namespace.paths] ollama_config = OllamaConfig( - chat_model=namespace.model, 
base_url=namespace.server_url, system_prompt=namespace.system_prompt, embedding_model=namespace.embedding_model + chat_model=namespace.model, + base_url=namespace.server_url, + system_prompt=namespace.system_prompt, + embedding_model=namespace.embedding_model, ) - return create_review_config(paths, ollama_config, namespace.base_branch) + return create_review_config( + paths, ollama_config, namespace.base_branch, namespace.verbose + ) diff --git a/src/reviewllama/llm.py b/src/reviewllama/llm.py index 73a7265..fd2522a 100644 --- a/src/reviewllama/llm.py +++ b/src/reviewllama/llm.py @@ -1,11 +1,13 @@ from dataclasses import dataclass from typing import Any -from langchain.memory import ConversationBufferMemory -from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain.schema import BaseMessage +from langchain_community.chat_message_histories import ChatMessageHistory +from langchain_core.chat_history import BaseChatMessageHistory +from langchain_core.messages import BaseMessage +from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.runnables import RunnableLambda from langchain_core.runnables.base import RunnableSerializable +from langchain_core.runnables.history import RunnableWithMessageHistory from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.vectorstores import VectorStoreRetriever from langchain_ollama import ChatOllama @@ -17,20 +19,21 @@ from .configs import OllamaConfig @dataclass(frozen=True) class ChatClient: - chain: RunnableSerializable[dict[str, Any], BaseMessage] - memory: ConversationBufferMemory + chain: RunnableWithMessageHistory + message_history: dict[str, BaseChatMessageHistory] - def get_last_response_or_none(self): + def get_last_response_or_none(self, session_id: str = "default"): try: - return self.memory.chat_memory.messages[-1] - except IndexError: + messages = self.message_history[session_id].messages + return messages[-1] if messages 
else None + except (KeyError, IndexError): return None def create_chat_chain( config: OllamaConfig, ) -> RunnableSerializable[dict[str, Any], BaseMessage]: - """Create the chat chain for use by the code.""" + """Create the base chat chain for use by the code.""" llm = ChatOllama( model=config.chat_model, base_url=config.base_url, @@ -40,28 +43,17 @@ def create_chat_chain( prompt = ChatPromptTemplate.from_messages( [ ("system", config.system_prompt), - MessagesPlaceholder("chat_history"), + MessagesPlaceholder("history"), ("human", "Context\n{context}\n\nQuestion:\n{input}"), ] ) - def get_chat_history(inputs: dict) -> list: - """Extract chat history from memory object.""" - try: - return inputs["memory"].chat_memory.messages - except AttributeError: - return [] - def get_context(inputs: dict) -> str: """Extract the RAG context from the input object""" - try: - return inputs["context"] - except AttributeError: - return "" + return inputs.get("context", "") return ( RunnablePassthrough.assign( - chat_history=RunnableLambda(get_chat_history), context=RunnableLambda(get_context), ) | prompt @@ -69,32 +61,48 @@ def create_chat_chain( ) -def create_memory(): - return ConversationBufferMemory(memory_key="chat_history", return_messages=True) +def get_session_history( + session_id: str, message_history: dict[str, BaseChatMessageHistory] +) -> BaseChatMessageHistory: + """Get or create chat message history for a session.""" + if session_id not in message_history: + message_history[session_id] = ChatMessageHistory() + return message_history[session_id] def create_chat_client(config: OllamaConfig): + base_chain = create_chat_chain(config) + message_history = {} + + chain_with_history = RunnableWithMessageHistory( + base_chain, + lambda session_id: get_session_history(session_id, message_history), + input_messages_key="input", + history_messages_key="history", + ) + return ChatClient( - chain=create_chat_chain(config), - memory=create_memory(), + chain=chain_with_history, + 
message_history=message_history, ) def chat_with_client( - client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None -) -> ChatClient: - + client: ChatClient, + message: str, + retriever: VectorStoreRetriever | None = None, + session_id: str = "default", + verbose: bool = False, +) -> str: + """Chat with the client and return the response content.""" if retriever: context = get_context_from_store(message, retriever) else: context = "" response = client.chain.invoke( - {"input": message, "memory": client.memory, "context": context} + {"input": message, "context": context}, + config={"configurable": {"session_id": session_id}}, ) - memory = client.memory - memory.chat_memory.add_user_message(message) - memory.chat_memory.add_ai_message(response.content) - - return ChatClient(chain=client.chain, memory=memory) + return response.content diff --git a/src/reviewllama/reviewllama.py b/src/reviewllama/reviewllama.py index ae0cbba..25ca909 100644 --- a/src/reviewllama/reviewllama.py +++ b/src/reviewllama/reviewllama.py @@ -4,8 +4,12 @@ from git import Repo from langchain_core.vectorstores import VectorStoreRetriever from reviewllama.configs import OllamaConfig, ReviewConfig -from reviewllama.git_diff import (GitAnalysis, GitDiff, analyze_git_repository, - get_tracked_files) +from reviewllama.git_diff import ( + GitAnalysis, + GitDiff, + analyze_git_repository, + get_tracked_files, +) from reviewllama.llm import ChatClient, chat_with_client, create_chat_client from reviewllama.vector_store import create_retriever @@ -19,7 +23,7 @@ def run_reviewllama(config: ReviewConfig): retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama) for diff in analysis.diffs: - chat_client = get_suggestions(diff, retriever, chat_client) + chat_client = get_suggestions(diff, retriever, chat_client, config.verbose) def create_and_log_chat_client(config: OllamaConfig) -> ChatClient: @@ -44,9 +48,12 @@ def create_and_log_vector_store_retriever( def 
get_suggestions( - diff: GitDiff, retriever: VectorStoreRetriever, chat_client: ChatClient + diff: GitDiff, + retriever: VectorStoreRetriever, + chat_client: ChatClient, + verbose: bool, ) -> ChatClient: - new_client = chat_with_client(chat_client, craft_message(diff), retriever) + new_client = chat_with_client(chat_client, craft_message(diff), retriever, verbose) log_info(str(new_client.get_last_response_or_none().content)) return new_client diff --git a/tests/test_llm.py b/tests/test_llm.py index faa3a4e..3551850 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -17,11 +17,12 @@ def test_chat_client(ollama_config, chat_client): if not is_ollama_available(ollama_config): pytest.skip("Local Ollama server is not available") - chat_client = chat_with_client( + response = chat_with_client( chat_client, "Tell me your name and introduce yourself briefly" ) - response = chat_client.get_last_response_or_none() + response_from_history = chat_client.get_last_response_or_none().content assert response is not None - assert len(response.content) > 0 - assert "gemma" in response.content.lower() + assert response == response_from_history + assert len(response) > 0 + assert "gemma" in response.lower()