diff --git a/src/reviewllama/llm.py b/src/reviewllama/llm.py
index fd2522a..f86cf3f 100644
--- a/src/reviewllama/llm.py
+++ b/src/reviewllama/llm.py
@@ -1,13 +1,10 @@
 from dataclasses import dataclass
 from typing import Any
 
-from langchain_community.chat_message_histories import ChatMessageHistory
-from langchain_core.chat_history import BaseChatMessageHistory
 from langchain_core.messages import BaseMessage
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
 from langchain_core.runnables import RunnableLambda
 from langchain_core.runnables.base import RunnableSerializable
-from langchain_core.runnables.history import RunnableWithMessageHistory
 from langchain_core.runnables.passthrough import RunnablePassthrough
 from langchain_core.vectorstores import VectorStoreRetriever
 from langchain_ollama import ChatOllama
@@ -19,15 +16,7 @@ from .configs import OllamaConfig
 
 @dataclass(frozen=True)
 class ChatClient:
-    chain: RunnableWithMessageHistory
-    message_history: dict[str, BaseChatMessageHistory]
-
-    def get_last_response_or_none(self, session_id: str = "default"):
-        try:
-            messages = self.message_history[session_id].messages
-            return messages[-1] if messages else None
-        except (KeyError, IndexError):
-            return None
+    chain: RunnableSerializable
 
 
 def create_chat_chain(
@@ -43,7 +32,6 @@
     prompt = ChatPromptTemplate.from_messages(
         [
             ("system", config.system_prompt),
-            MessagesPlaceholder("history"),
             ("human", "Context\n{context}\n\nQuestion:\n{input}"),
         ]
     )
@@ -61,29 +49,11 @@
     )
 
 
-def get_session_history(
-    session_id: str, message_history: dict[str, BaseChatMessageHistory]
-) -> BaseChatMessageHistory:
-    """Get or create chat message history for a session."""
-    if session_id not in message_history:
-        message_history[session_id] = ChatMessageHistory()
-    return message_history[session_id]
-
-
 def create_chat_client(config: OllamaConfig):
     base_chain = create_chat_chain(config)
-    message_history = {}
-
-    chain_with_history = RunnableWithMessageHistory(
-        base_chain,
-        lambda session_id: get_session_history(session_id, message_history),
-        input_messages_key="input",
-        history_messages_key="history",
-    )
 
     return ChatClient(
-        chain=chain_with_history,
-        message_history=message_history,
+        chain=base_chain,
     )
 
 
diff --git a/src/reviewllama/reviewllama.py b/src/reviewllama/reviewllama.py
index 25ca909..a1cf01f 100644
--- a/src/reviewllama/reviewllama.py
+++ b/src/reviewllama/reviewllama.py
@@ -4,12 +4,8 @@ from git import Repo
 from langchain_core.vectorstores import VectorStoreRetriever
 
 from reviewllama.configs import OllamaConfig, ReviewConfig
-from reviewllama.git_diff import (
-    GitAnalysis,
-    GitDiff,
-    analyze_git_repository,
-    get_tracked_files,
-)
+from reviewllama.git_diff import (GitAnalysis, GitDiff, analyze_git_repository,
+                                  get_tracked_files)
 from reviewllama.llm import ChatClient, chat_with_client, create_chat_client
 from reviewllama.vector_store import create_retriever
 
@@ -23,7 +19,7 @@ def run_reviewllama(config: ReviewConfig):
     retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama)
 
     for diff in analysis.diffs:
-        chat_client = get_suggestions(diff, retriever, chat_client, config.verbose)
+        get_suggestions(diff, retriever, chat_client, config.verbose)
 
 
 def create_and_log_chat_client(config: OllamaConfig) -> ChatClient:
@@ -52,10 +48,9 @@ def get_suggestions(
     retriever: VectorStoreRetriever,
     chat_client: ChatClient,
     verbose: bool,
-) -> ChatClient:
-    new_client = chat_with_client(chat_client, craft_message(diff), retriever, verbose)
-    log_info(str(new_client.get_last_response_or_none().content))
-    return new_client
+):
+    response = chat_with_client(chat_client, craft_message(diff), retriever, verbose)
+    log_info(response)
 
 
 def craft_message(diff) -> str:
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 3551850..77fe6f5 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -21,8 +21,6 @@ def test_chat_client(ollama_config, chat_client):
         chat_client, "Tell me your name and introduce yourself briefly"
     )
 
-    response_from_history = chat_client.get_last_response_or_none().content
     assert response is not None
-    assert response == response_from_history
     assert len(response) > 0
     assert "gemma" in response.lower()