Compare commits

...

3 commits

Author SHA1 Message Date
f17b6721f9
Update prompts 2025-08-05 10:18:53 -04:00
c228a7298a
Remove message history 2025-08-01 14:21:18 -04:00
1c75cfc716
Improve LangChain memory implementation
- Swap to RunnableWithMemory
- Add verbosity flag
2025-07-18 22:21:31 -04:00
5 changed files with 57 additions and 83 deletions

View file

@ -69,27 +69,24 @@ Examples:
default=( default=(
"You are a PR review assistant in charge of software quality control. " "You are a PR review assistant in charge of software quality control. "
"You analyze code changes in the context of the full code base to verify style, " "You analyze code changes in the context of the full code base to verify style, "
"syntax, and functionality. Your response should contain exactly 3 suggestions. " "syntax, and functionality. You respond with suggestions to improve code quality "
"Each suggestion must be in the following format:\n" "You only provide sugestions when you find a flaw in the code otherwise you say that "
"```diff\n" "no issues were found. Each suggestion should reference the old code and the new "
"-<old code>\n" "suggested code."
"+<new code>\n" "Do not provide an analysis of the code and do not summarize suggestions. "
"```\n" "Answer as briefly as possible and return only the suggestions in the requested format "
"Reason: <explanation>\n\n" "with bullet points and no extra text. Provide examples when appropriate."
"Here are two examples of the required format:\n"
"```diff\n"
"-somvr = 2 + 2\n"
"+somevar = 2 + 2\n"
"```\n"
"Reason: somvr is likely a typo, try replacing with somevar\n\n"
"```diff\n"
"-add_two_numbers(\"1\", \"2\")\n"
"+add_two_numbers(1,2)\n"
"```\n"
"Reason: add_two_numbers requires numeric values and does not accept strings"
), ),
help="Base branch to compare against (default: %(default)s)", help="Base branch to compare against (default: %(default)s)",
) )
parser.add_argument(
"-v",
"--verbose",
action="store_true",
default=False,
help="Enable verbose output, including messages sent to the LLM.",
)
return parser return parser

View file

@ -13,7 +13,7 @@ class OllamaConfig:
base_url: str base_url: str
system_prompt: str system_prompt: str
# TODO: Update this to be a passed in value # TODO: Update this to be a passed in value
temperature: float = field(default=0.7) temperature: float = field(default=0.0)
@dataclass(frozen=True) @dataclass(frozen=True)
@ -23,13 +23,14 @@ class ReviewConfig:
paths: List[Path] paths: List[Path]
ollama: OllamaConfig ollama: OllamaConfig
base_branch: str base_branch: str
verbose: bool
def create_ollama_config( def create_ollama_config(
model: str, model: str,
server_url: str, server_url: str,
system_prompt: str, system_prompt: str,
temperature=0.7, temperature=0.0,
embedding_model="nomic-embed-text", embedding_model="nomic-embed-text",
) -> OllamaConfig: ) -> OllamaConfig:
"""Create OllamaConfig with validated parameters.""" """Create OllamaConfig with validated parameters."""
@ -43,19 +44,24 @@ def create_ollama_config(
def create_review_config( def create_review_config(
paths: List[Path], ollama_config: OllamaConfig, base_branch: str paths: List[Path], ollama_config: OllamaConfig, base_branch: str, verbose
) -> ReviewConfig: ) -> ReviewConfig:
"""Create complete ReviewConfig from validated components.""" """Create complete ReviewConfig from validated components."""
return ReviewConfig(paths=paths, ollama=ollama_config, base_branch=base_branch) return ReviewConfig(
paths=paths, ollama=ollama_config, base_branch=base_branch, verbose=verbose
)
def namespace_to_config( def namespace_to_config(namespace: argparse.Namespace):
namespace: argparse.Namespace
):
"""Transform argparse namespace into ReviewConfig.""" """Transform argparse namespace into ReviewConfig."""
paths = [Path(path_str) for path_str in namespace.paths] paths = [Path(path_str) for path_str in namespace.paths]
ollama_config = OllamaConfig( ollama_config = OllamaConfig(
chat_model=namespace.model, base_url=namespace.server_url, system_prompt=namespace.system_prompt, embedding_model=namespace.embedding_model chat_model=namespace.model,
base_url=namespace.server_url,
system_prompt=namespace.system_prompt,
embedding_model=namespace.embedding_model,
) )
return create_review_config(paths, ollama_config, namespace.base_branch) return create_review_config(
paths, ollama_config, namespace.base_branch, namespace.verbose
)

View file

@ -1,9 +1,8 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any from typing import Any
from langchain.memory import ConversationBufferMemory from langchain_core.messages import BaseMessage
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder from langchain_core.prompts import ChatPromptTemplate
from langchain.schema import BaseMessage
from langchain_core.runnables import RunnableLambda from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import RunnableSerializable from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.passthrough import RunnablePassthrough from langchain_core.runnables.passthrough import RunnablePassthrough
@ -17,20 +16,13 @@ from .configs import OllamaConfig
@dataclass(frozen=True) @dataclass(frozen=True)
class ChatClient: class ChatClient:
chain: RunnableSerializable[dict[str, Any], BaseMessage] chain: RunnableSerializable
memory: ConversationBufferMemory
def get_last_response_or_none(self):
try:
return self.memory.chat_memory.messages[-1]
except IndexError:
return None
def create_chat_chain( def create_chat_chain(
config: OllamaConfig, config: OllamaConfig,
) -> RunnableSerializable[dict[str, Any], BaseMessage]: ) -> RunnableSerializable[dict[str, Any], BaseMessage]:
"""Create the chat chain for use by the code.""" """Create the base chat chain for use by the code."""
llm = ChatOllama( llm = ChatOllama(
model=config.chat_model, model=config.chat_model,
base_url=config.base_url, base_url=config.base_url,
@ -40,28 +32,16 @@ def create_chat_chain(
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
[ [
("system", config.system_prompt), ("system", config.system_prompt),
MessagesPlaceholder("chat_history"),
("human", "Context\n{context}\n\nQuestion:\n{input}"), ("human", "Context\n{context}\n\nQuestion:\n{input}"),
] ]
) )
def get_chat_history(inputs: dict) -> list:
"""Extract chat history from memory object."""
try:
return inputs["memory"].chat_memory.messages
except AttributeError:
return []
def get_context(inputs: dict) -> str: def get_context(inputs: dict) -> str:
"""Extract the RAG context from the input object""" """Extract the RAG context from the input object"""
try: return inputs.get("context", "")
return inputs["context"]
except AttributeError:
return ""
return ( return (
RunnablePassthrough.assign( RunnablePassthrough.assign(
chat_history=RunnableLambda(get_chat_history),
context=RunnableLambda(get_context), context=RunnableLambda(get_context),
) )
| prompt | prompt
@ -69,32 +49,28 @@ def create_chat_chain(
) )
def create_memory():
return ConversationBufferMemory(memory_key="chat_history", return_messages=True)
def create_chat_client(config: OllamaConfig): def create_chat_client(config: OllamaConfig):
base_chain = create_chat_chain(config)
return ChatClient( return ChatClient(
chain=create_chat_chain(config), chain=base_chain,
memory=create_memory(),
) )
def chat_with_client( def chat_with_client(
client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None client: ChatClient,
) -> ChatClient: message: str,
retriever: VectorStoreRetriever | None = None,
verbose: bool = False,
) -> str:
"""Chat with the client and return the response content."""
if retriever: if retriever:
context = get_context_from_store(message, retriever) context = get_context_from_store(message, retriever)
else: else:
context = "" context = ""
response = client.chain.invoke( response = client.chain.invoke(
{"input": message, "memory": client.memory, "context": context} {"input": message, "context": context},
) )
memory = client.memory return response.content
memory.chat_memory.add_user_message(message)
memory.chat_memory.add_ai_message(response.content)
return ChatClient(chain=client.chain, memory=memory)

View file

@ -19,7 +19,7 @@ def run_reviewllama(config: ReviewConfig):
retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama) retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama)
for diff in analysis.diffs: for diff in analysis.diffs:
chat_client = get_suggestions(diff, retriever, chat_client) get_suggestions(diff, retriever, chat_client, config.verbose)
def create_and_log_chat_client(config: OllamaConfig) -> ChatClient: def create_and_log_chat_client(config: OllamaConfig) -> ChatClient:
@ -44,20 +44,16 @@ def create_and_log_vector_store_retriever(
def get_suggestions( def get_suggestions(
diff: GitDiff, retriever: VectorStoreRetriever, chat_client: ChatClient diff: GitDiff,
) -> ChatClient: retriever: VectorStoreRetriever,
new_client = chat_with_client(chat_client, craft_message(diff), retriever) chat_client: ChatClient,
log_info(str(new_client.get_last_response_or_none().content)) verbose: bool,
return new_client ):
response = chat_with_client(chat_client, craft_message(diff), retriever, verbose)
log_info(response)
def craft_message(diff) -> str: def craft_message(diff) -> str:
return ( return (
"Review the following code changes and make up to three suggestions on "
"how to improve it. If the code is sufficiently simple or accurate then say "
"no suggestions can be found. Important issues you should consider are consistent "
"style, introduction of syntax errors, and potentially breaking changes in "
"interfaces/APIs that aren't properly handled.\n\n"
f"The original code:\n```\n{diff.old_content}\n```\n"
f"The new code:\n```\n{diff.new_content}```" f"The new code:\n```\n{diff.new_content}```"
) )

View file

@ -17,11 +17,10 @@ def test_chat_client(ollama_config, chat_client):
if not is_ollama_available(ollama_config): if not is_ollama_available(ollama_config):
pytest.skip("Local Ollama server is not available") pytest.skip("Local Ollama server is not available")
chat_client = chat_with_client( response = chat_with_client(
chat_client, "Tell me your name and introduce yourself briefly" chat_client, "Tell me your name and introduce yourself briefly"
) )
response = chat_client.get_last_response_or_none()
assert response is not None assert response is not None
assert len(response.content) > 0 assert len(response) > 0
assert "gemma" in response.content.lower() assert "gemma" in response.lower()