Remove message history

This commit is contained in:
Alex Selimov 2025-08-01 14:21:18 -04:00
parent 1c75cfc716
commit c228a7298a
Signed by: aselimov
GPG key ID: 3DDB9C3E023F1F31
3 changed files with 8 additions and 45 deletions

View file

@@ -1,13 +1,10 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import RunnableSerializable from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever from langchain_core.vectorstores import VectorStoreRetriever
from langchain_ollama import ChatOllama from langchain_ollama import ChatOllama
@@ -19,15 +16,7 @@ from .configs import OllamaConfig
@dataclass(frozen=True) @dataclass(frozen=True)
class ChatClient: class ChatClient:
chain: RunnableWithMessageHistory chain: RunnableSerializable
message_history: dict[str, BaseChatMessageHistory]
def get_last_response_or_none(self, session_id: str = "default"):
try:
messages = self.message_history[session_id].messages
return messages[-1] if messages else None
except (KeyError, IndexError):
return None
def create_chat_chain( def create_chat_chain(
@@ -43,7 +32,6 @@ def create_chat_chain(
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
[ [
("system", config.system_prompt), ("system", config.system_prompt),
MessagesPlaceholder("history"),
("human", "Context\n{context}\n\nQuestion:\n{input}"), ("human", "Context\n{context}\n\nQuestion:\n{input}"),
] ]
) )
@@ -61,29 +49,11 @@ def create_chat_chain(
) )
def get_session_history(
session_id: str, message_history: dict[str, BaseChatMessageHistory]
) -> BaseChatMessageHistory:
"""Get or create chat message history for a session."""
if session_id not in message_history:
message_history[session_id] = ChatMessageHistory()
return message_history[session_id]
def create_chat_client(config: OllamaConfig): def create_chat_client(config: OllamaConfig):
base_chain = create_chat_chain(config) base_chain = create_chat_chain(config)
message_history = {}
chain_with_history = RunnableWithMessageHistory(
base_chain,
lambda session_id: get_session_history(session_id, message_history),
input_messages_key="input",
history_messages_key="history",
)
return ChatClient( return ChatClient(
chain=chain_with_history, chain=base_chain,
message_history=message_history,
) )

View file

@@ -4,12 +4,8 @@ from git import Repo
from langchain_core.vectorstores import VectorStoreRetriever from langchain_core.vectorstores import VectorStoreRetriever
from reviewllama.configs import OllamaConfig, ReviewConfig from reviewllama.configs import OllamaConfig, ReviewConfig
from reviewllama.git_diff import ( from reviewllama.git_diff import (GitAnalysis, GitDiff, analyze_git_repository,
GitAnalysis, get_tracked_files)
GitDiff,
analyze_git_repository,
get_tracked_files,
)
from reviewllama.llm import ChatClient, chat_with_client, create_chat_client from reviewllama.llm import ChatClient, chat_with_client, create_chat_client
from reviewllama.vector_store import create_retriever from reviewllama.vector_store import create_retriever
@@ -23,7 +19,7 @@ def run_reviewllama(config: ReviewConfig):
retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama) retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama)
for diff in analysis.diffs: for diff in analysis.diffs:
chat_client = get_suggestions(diff, retriever, chat_client, config.verbose) get_suggestions(diff, retriever, chat_client, config.verbose)
def create_and_log_chat_client(config: OllamaConfig) -> ChatClient: def create_and_log_chat_client(config: OllamaConfig) -> ChatClient:
@@ -52,10 +48,9 @@ def get_suggestions(
retriever: VectorStoreRetriever, retriever: VectorStoreRetriever,
chat_client: ChatClient, chat_client: ChatClient,
verbose: bool, verbose: bool,
) -> ChatClient: ):
new_client = chat_with_client(chat_client, craft_message(diff), retriever, verbose) response = chat_with_client(chat_client, craft_message(diff), retriever, verbose)
log_info(str(new_client.get_last_response_or_none().content)) log_info(response)
return new_client
def craft_message(diff) -> str: def craft_message(diff) -> str:

View file

@@ -21,8 +21,6 @@ def test_chat_client(ollama_config, chat_client):
chat_client, "Tell me your name and introduce yourself briefly" chat_client, "Tell me your name and introduce yourself briefly"
) )
response_from_history = chat_client.get_last_response_or_none().content
assert response is not None assert response is not None
assert response == response_from_history
assert len(response) > 0 assert len(response) > 0
assert "gemma" in response.lower() assert "gemma" in response.lower()