Improve LangChain memory implementation
- Swap to RunnableWithMemory - Add verbosity flag
This commit is contained in:
parent
e59cf01ba9
commit
1c75cfc716
5 changed files with 81 additions and 51 deletions
|
@ -90,6 +90,14 @@ Examples:
|
|||
),
|
||||
help="Base branch to compare against (default: %(default)s)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Enable verbose output, including messages sent to the LLM.",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ class ReviewConfig:
|
|||
paths: List[Path]
|
||||
ollama: OllamaConfig
|
||||
base_branch: str
|
||||
verbose: bool
|
||||
|
||||
|
||||
def create_ollama_config(
|
||||
|
@ -43,19 +44,24 @@ def create_ollama_config(
|
|||
|
||||
|
||||
def create_review_config(
|
||||
paths: List[Path], ollama_config: OllamaConfig, base_branch: str
|
||||
paths: List[Path], ollama_config: OllamaConfig, base_branch: str, verbose
|
||||
) -> ReviewConfig:
|
||||
"""Create complete ReviewConfig from validated components."""
|
||||
return ReviewConfig(paths=paths, ollama=ollama_config, base_branch=base_branch)
|
||||
return ReviewConfig(
|
||||
paths=paths, ollama=ollama_config, base_branch=base_branch, verbose=verbose
|
||||
)
|
||||
|
||||
|
||||
def namespace_to_config(
|
||||
namespace: argparse.Namespace
|
||||
):
|
||||
def namespace_to_config(namespace: argparse.Namespace):
|
||||
"""Transform argparse namespace into ReviewConfig."""
|
||||
paths = [Path(path_str) for path_str in namespace.paths]
|
||||
ollama_config = OllamaConfig(
|
||||
chat_model=namespace.model, base_url=namespace.server_url, system_prompt=namespace.system_prompt, embedding_model=namespace.embedding_model
|
||||
chat_model=namespace.model,
|
||||
base_url=namespace.server_url,
|
||||
system_prompt=namespace.system_prompt,
|
||||
embedding_model=namespace.embedding_model,
|
||||
)
|
||||
|
||||
return create_review_config(paths, ollama_config, namespace.base_branch)
|
||||
return create_review_config(
|
||||
paths, ollama_config, namespace.base_branch, namespace.verbose
|
||||
)
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema import BaseMessage
|
||||
from langchain_community.chat_message_histories import ChatMessageHistory
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langchain_core.runnables.base import RunnableSerializable
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
from langchain_ollama import ChatOllama
|
||||
|
@ -17,20 +19,21 @@ from .configs import OllamaConfig
|
|||
|
||||
@dataclass(frozen=True)
|
||||
class ChatClient:
|
||||
chain: RunnableSerializable[dict[str, Any], BaseMessage]
|
||||
memory: ConversationBufferMemory
|
||||
chain: RunnableWithMessageHistory
|
||||
message_history: dict[str, BaseChatMessageHistory]
|
||||
|
||||
def get_last_response_or_none(self):
|
||||
def get_last_response_or_none(self, session_id: str = "default"):
|
||||
try:
|
||||
return self.memory.chat_memory.messages[-1]
|
||||
except IndexError:
|
||||
messages = self.message_history[session_id].messages
|
||||
return messages[-1] if messages else None
|
||||
except (KeyError, IndexError):
|
||||
return None
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
config: OllamaConfig,
|
||||
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
|
||||
"""Create the chat chain for use by the code."""
|
||||
"""Create the base chat chain for use by the code."""
|
||||
llm = ChatOllama(
|
||||
model=config.chat_model,
|
||||
base_url=config.base_url,
|
||||
|
@ -40,28 +43,17 @@ def create_chat_chain(
|
|||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", config.system_prompt),
|
||||
MessagesPlaceholder("chat_history"),
|
||||
MessagesPlaceholder("history"),
|
||||
("human", "Context\n{context}\n\nQuestion:\n{input}"),
|
||||
]
|
||||
)
|
||||
|
||||
def get_chat_history(inputs: dict) -> list:
|
||||
"""Extract chat history from memory object."""
|
||||
try:
|
||||
return inputs["memory"].chat_memory.messages
|
||||
except AttributeError:
|
||||
return []
|
||||
|
||||
def get_context(inputs: dict) -> str:
|
||||
"""Extract the RAG context from the input object"""
|
||||
try:
|
||||
return inputs["context"]
|
||||
except AttributeError:
|
||||
return ""
|
||||
return inputs.get("context", "")
|
||||
|
||||
return (
|
||||
RunnablePassthrough.assign(
|
||||
chat_history=RunnableLambda(get_chat_history),
|
||||
context=RunnableLambda(get_context),
|
||||
)
|
||||
| prompt
|
||||
|
@ -69,32 +61,48 @@ def create_chat_chain(
|
|||
)
|
||||
|
||||
|
||||
def create_memory():
|
||||
return ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
||||
def get_session_history(
|
||||
session_id: str, message_history: dict[str, BaseChatMessageHistory]
|
||||
) -> BaseChatMessageHistory:
|
||||
"""Get or create chat message history for a session."""
|
||||
if session_id not in message_history:
|
||||
message_history[session_id] = ChatMessageHistory()
|
||||
return message_history[session_id]
|
||||
|
||||
|
||||
def create_chat_client(config: OllamaConfig):
|
||||
base_chain = create_chat_chain(config)
|
||||
message_history = {}
|
||||
|
||||
chain_with_history = RunnableWithMessageHistory(
|
||||
base_chain,
|
||||
lambda session_id: get_session_history(session_id, message_history),
|
||||
input_messages_key="input",
|
||||
history_messages_key="history",
|
||||
)
|
||||
|
||||
return ChatClient(
|
||||
chain=create_chat_chain(config),
|
||||
memory=create_memory(),
|
||||
chain=chain_with_history,
|
||||
message_history=message_history,
|
||||
)
|
||||
|
||||
|
||||
def chat_with_client(
|
||||
client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None
|
||||
) -> ChatClient:
|
||||
|
||||
client: ChatClient,
|
||||
message: str,
|
||||
retriever: VectorStoreRetriever | None = None,
|
||||
session_id: str = "default",
|
||||
verbose: bool = False,
|
||||
) -> str:
|
||||
"""Chat with the client and return the response content."""
|
||||
if retriever:
|
||||
context = get_context_from_store(message, retriever)
|
||||
else:
|
||||
context = ""
|
||||
|
||||
response = client.chain.invoke(
|
||||
{"input": message, "memory": client.memory, "context": context}
|
||||
{"input": message, "context": context},
|
||||
config={"configurable": {"session_id": session_id}},
|
||||
)
|
||||
|
||||
memory = client.memory
|
||||
memory.chat_memory.add_user_message(message)
|
||||
memory.chat_memory.add_ai_message(response.content)
|
||||
|
||||
return ChatClient(chain=client.chain, memory=memory)
|
||||
return response.content
|
||||
|
|
|
@ -4,8 +4,12 @@ from git import Repo
|
|||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
|
||||
from reviewllama.configs import OllamaConfig, ReviewConfig
|
||||
from reviewllama.git_diff import (GitAnalysis, GitDiff, analyze_git_repository,
|
||||
get_tracked_files)
|
||||
from reviewllama.git_diff import (
|
||||
GitAnalysis,
|
||||
GitDiff,
|
||||
analyze_git_repository,
|
||||
get_tracked_files,
|
||||
)
|
||||
from reviewllama.llm import ChatClient, chat_with_client, create_chat_client
|
||||
from reviewllama.vector_store import create_retriever
|
||||
|
||||
|
@ -19,7 +23,7 @@ def run_reviewllama(config: ReviewConfig):
|
|||
retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama)
|
||||
|
||||
for diff in analysis.diffs:
|
||||
chat_client = get_suggestions(diff, retriever, chat_client)
|
||||
chat_client = get_suggestions(diff, retriever, chat_client, config.verbose)
|
||||
|
||||
|
||||
def create_and_log_chat_client(config: OllamaConfig) -> ChatClient:
|
||||
|
@ -44,9 +48,12 @@ def create_and_log_vector_store_retriever(
|
|||
|
||||
|
||||
def get_suggestions(
|
||||
diff: GitDiff, retriever: VectorStoreRetriever, chat_client: ChatClient
|
||||
diff: GitDiff,
|
||||
retriever: VectorStoreRetriever,
|
||||
chat_client: ChatClient,
|
||||
verbose: bool,
|
||||
) -> ChatClient:
|
||||
new_client = chat_with_client(chat_client, craft_message(diff), retriever)
|
||||
new_client = chat_with_client(chat_client, craft_message(diff), retriever, verbose)
|
||||
log_info(str(new_client.get_last_response_or_none().content))
|
||||
return new_client
|
||||
|
||||
|
|
|
@ -17,11 +17,12 @@ def test_chat_client(ollama_config, chat_client):
|
|||
if not is_ollama_available(ollama_config):
|
||||
pytest.skip("Local Ollama server is not available")
|
||||
|
||||
chat_client = chat_with_client(
|
||||
response = chat_with_client(
|
||||
chat_client, "Tell me your name and introduce yourself briefly"
|
||||
)
|
||||
response = chat_client.get_last_response_or_none()
|
||||
|
||||
response_from_history = chat_client.get_last_response_or_none().content
|
||||
assert response is not None
|
||||
assert len(response.content) > 0
|
||||
assert "gemma" in response.content.lower()
|
||||
assert response == response_from_history
|
||||
assert len(response) > 0
|
||||
assert "gemma" in response.lower()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue