Remove message history
This commit is contained in:
parent
1c75cfc716
commit
c228a7298a
3 changed files with 8 additions and 45 deletions
|
@ -1,13 +1,10 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_community.chat_message_histories import ChatMessageHistory
|
|
||||||
from langchain_core.chat_history import BaseChatMessageHistory
|
|
||||||
from langchain_core.messages import BaseMessage
|
from langchain_core.messages import BaseMessage
|
||||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||||
from langchain_core.runnables import RunnableLambda
|
from langchain_core.runnables import RunnableLambda
|
||||||
from langchain_core.runnables.base import RunnableSerializable
|
from langchain_core.runnables.base import RunnableSerializable
|
||||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
|
||||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||||
from langchain_core.vectorstores import VectorStoreRetriever
|
from langchain_core.vectorstores import VectorStoreRetriever
|
||||||
from langchain_ollama import ChatOllama
|
from langchain_ollama import ChatOllama
|
||||||
|
@ -19,15 +16,7 @@ from .configs import OllamaConfig
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class ChatClient:
|
class ChatClient:
|
||||||
chain: RunnableWithMessageHistory
|
chain: RunnableSerializable
|
||||||
message_history: dict[str, BaseChatMessageHistory]
|
|
||||||
|
|
||||||
def get_last_response_or_none(self, session_id: str = "default"):
|
|
||||||
try:
|
|
||||||
messages = self.message_history[session_id].messages
|
|
||||||
return messages[-1] if messages else None
|
|
||||||
except (KeyError, IndexError):
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def create_chat_chain(
|
def create_chat_chain(
|
||||||
|
@ -43,7 +32,6 @@ def create_chat_chain(
|
||||||
prompt = ChatPromptTemplate.from_messages(
|
prompt = ChatPromptTemplate.from_messages(
|
||||||
[
|
[
|
||||||
("system", config.system_prompt),
|
("system", config.system_prompt),
|
||||||
MessagesPlaceholder("history"),
|
|
||||||
("human", "Context\n{context}\n\nQuestion:\n{input}"),
|
("human", "Context\n{context}\n\nQuestion:\n{input}"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
@ -61,29 +49,11 @@ def create_chat_chain(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_session_history(
|
|
||||||
session_id: str, message_history: dict[str, BaseChatMessageHistory]
|
|
||||||
) -> BaseChatMessageHistory:
|
|
||||||
"""Get or create chat message history for a session."""
|
|
||||||
if session_id not in message_history:
|
|
||||||
message_history[session_id] = ChatMessageHistory()
|
|
||||||
return message_history[session_id]
|
|
||||||
|
|
||||||
|
|
||||||
def create_chat_client(config: OllamaConfig):
|
def create_chat_client(config: OllamaConfig):
|
||||||
base_chain = create_chat_chain(config)
|
base_chain = create_chat_chain(config)
|
||||||
message_history = {}
|
|
||||||
|
|
||||||
chain_with_history = RunnableWithMessageHistory(
|
|
||||||
base_chain,
|
|
||||||
lambda session_id: get_session_history(session_id, message_history),
|
|
||||||
input_messages_key="input",
|
|
||||||
history_messages_key="history",
|
|
||||||
)
|
|
||||||
|
|
||||||
return ChatClient(
|
return ChatClient(
|
||||||
chain=chain_with_history,
|
chain=base_chain,
|
||||||
message_history=message_history,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,8 @@ from git import Repo
|
||||||
from langchain_core.vectorstores import VectorStoreRetriever
|
from langchain_core.vectorstores import VectorStoreRetriever
|
||||||
|
|
||||||
from reviewllama.configs import OllamaConfig, ReviewConfig
|
from reviewllama.configs import OllamaConfig, ReviewConfig
|
||||||
from reviewllama.git_diff import (
|
from reviewllama.git_diff import (GitAnalysis, GitDiff, analyze_git_repository,
|
||||||
GitAnalysis,
|
get_tracked_files)
|
||||||
GitDiff,
|
|
||||||
analyze_git_repository,
|
|
||||||
get_tracked_files,
|
|
||||||
)
|
|
||||||
from reviewllama.llm import ChatClient, chat_with_client, create_chat_client
|
from reviewllama.llm import ChatClient, chat_with_client, create_chat_client
|
||||||
from reviewllama.vector_store import create_retriever
|
from reviewllama.vector_store import create_retriever
|
||||||
|
|
||||||
|
@ -23,7 +19,7 @@ def run_reviewllama(config: ReviewConfig):
|
||||||
retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama)
|
retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama)
|
||||||
|
|
||||||
for diff in analysis.diffs:
|
for diff in analysis.diffs:
|
||||||
chat_client = get_suggestions(diff, retriever, chat_client, config.verbose)
|
get_suggestions(diff, retriever, chat_client, config.verbose)
|
||||||
|
|
||||||
|
|
||||||
def create_and_log_chat_client(config: OllamaConfig) -> ChatClient:
|
def create_and_log_chat_client(config: OllamaConfig) -> ChatClient:
|
||||||
|
@ -52,10 +48,9 @@ def get_suggestions(
|
||||||
retriever: VectorStoreRetriever,
|
retriever: VectorStoreRetriever,
|
||||||
chat_client: ChatClient,
|
chat_client: ChatClient,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
) -> ChatClient:
|
):
|
||||||
new_client = chat_with_client(chat_client, craft_message(diff), retriever, verbose)
|
response = chat_with_client(chat_client, craft_message(diff), retriever, verbose)
|
||||||
log_info(str(new_client.get_last_response_or_none().content))
|
log_info(response)
|
||||||
return new_client
|
|
||||||
|
|
||||||
|
|
||||||
def craft_message(diff) -> str:
|
def craft_message(diff) -> str:
|
||||||
|
|
|
@ -21,8 +21,6 @@ def test_chat_client(ollama_config, chat_client):
|
||||||
chat_client, "Tell me your name and introduce yourself briefly"
|
chat_client, "Tell me your name and introduce yourself briefly"
|
||||||
)
|
)
|
||||||
|
|
||||||
response_from_history = chat_client.get_last_response_or_none().content
|
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response == response_from_history
|
|
||||||
assert len(response) > 0
|
assert len(response) > 0
|
||||||
assert "gemma" in response.lower()
|
assert "gemma" in response.lower()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue