diff --git a/src/reviewllama/cli.py b/src/reviewllama/cli.py
index bdac075..43fa565 100644
--- a/src/reviewllama/cli.py
+++ b/src/reviewllama/cli.py
@@ -69,24 +69,27 @@ Examples:
         default=(
             "You are a PR review assistant in charge of software quality control. "
             "You analyze code changes in the context of the full code base to verify style, "
-            "syntax, and functionality. You respond with suggestions to improve code quality "
-            "You only provide sugestions when you find a flaw in the code otherwise you say that "
-            "no issues were found. Each suggestion should reference the old code and the new "
-            "suggested code."
-            "Do not provide an analysis of the code and do not summarize suggestions. "
-            "Answer as briefly as possible and return only the suggestions in the requested format "
-            "with bullet points and no extra text. Provide examples when appropriate."
+            "syntax, and functionality. Your response should contain exactly 3 suggestions. "
+            "Each suggestion must be in the following format:\n"
+            "```diff\n"
+            "-\n"
+            "+\n"
+            "```\n"
+            "Reason: \n\n"
+            "Here are two examples of the required format:\n"
+            "```diff\n"
+            "-somvr = 2 + 2\n"
+            "+somevar = 2 + 2\n"
+            "```\n"
+            "Reason: somvr is likely a typo, try replacing with somevar\n\n"
+            "```diff\n"
+            "-add_two_numbers(\"1\", \"2\")\n"
+            "+add_two_numbers(1,2)\n"
+            "```\n"
+            "Reason: add_two_numbers requires numeric values and does not accept strings"
         ),
         help="Base branch to compare against (default: %(default)s)",
     )
-
-    parser.add_argument(
-        "-v",
-        "--verbose",
-        action="store_true",
-        default=False,
-        help="Enable verbose output, including messages sent to the LLM.",
-    )
 
     return parser
 
diff --git a/src/reviewllama/configs.py b/src/reviewllama/configs.py
index bce1009..d577915 100644
--- a/src/reviewllama/configs.py
+++ b/src/reviewllama/configs.py
@@ -13,7 +13,7 @@ class OllamaConfig:
     base_url: str
     system_prompt: str
     # TODO: Update this to be a passed in value
-    temperature: float = field(default=0.0)
+    temperature: float = field(default=0.7)
 
 
 @dataclass(frozen=True)
@@ -23,14 +23,13 @@ class ReviewConfig:
     paths: List[Path]
     ollama: OllamaConfig
     base_branch: str
-    verbose: bool
 
 
 def create_ollama_config(
     model: str,
     server_url: str,
     system_prompt: str,
-    temperature=0.0,
+    temperature=0.7,
     embedding_model="nomic-embed-text",
 ) -> OllamaConfig:
     """Create OllamaConfig with validated parameters."""
@@ -44,24 +43,19 @@ def create_ollama_config(
 
 
 def create_review_config(
-    paths: List[Path], ollama_config: OllamaConfig, base_branch: str, verbose
+    paths: List[Path], ollama_config: OllamaConfig, base_branch: str
 ) -> ReviewConfig:
     """Create complete ReviewConfig from validated components."""
-    return ReviewConfig(
-        paths=paths, ollama=ollama_config, base_branch=base_branch, verbose=verbose
-    )
+    return ReviewConfig(paths=paths, ollama=ollama_config, base_branch=base_branch)
 
 
 
-def namespace_to_config(namespace: argparse.Namespace):
+def namespace_to_config(
+    namespace: argparse.Namespace
+):
     """Transform argparse namespace into ReviewConfig."""
     paths = [Path(path_str) for path_str in namespace.paths]
     ollama_config = OllamaConfig(
-        chat_model=namespace.model,
-        base_url=namespace.server_url,
-        system_prompt=namespace.system_prompt,
-        embedding_model=namespace.embedding_model,
+        chat_model=namespace.model, base_url=namespace.server_url, system_prompt=namespace.system_prompt, embedding_model=namespace.embedding_model
     )
-    return create_review_config(
-        paths, ollama_config, namespace.base_branch, namespace.verbose
-    )
+    return create_review_config(paths, ollama_config, namespace.base_branch)
diff --git a/src/reviewllama/llm.py b/src/reviewllama/llm.py
index 99e29c4..73a7265 100644
--- a/src/reviewllama/llm.py
+++ b/src/reviewllama/llm.py
@@ -1,8 +1,9 @@
 from dataclasses import dataclass
 from typing import Any
 
-from langchain_core.messages import BaseMessage
-from langchain_core.prompts import ChatPromptTemplate
+from langchain.memory import ConversationBufferMemory
+from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain.schema import BaseMessage
 from langchain_core.runnables import RunnableLambda
 from langchain_core.runnables.base import RunnableSerializable
 from langchain_core.runnables.passthrough import RunnablePassthrough
@@ -16,13 +17,20 @@ from .configs import OllamaConfig
 
 @dataclass(frozen=True)
 class ChatClient:
-    chain: RunnableSerializable
+    chain: RunnableSerializable[dict[str, Any], BaseMessage]
+    memory: ConversationBufferMemory
+
+    def get_last_response_or_none(self):
+        try:
+            return self.memory.chat_memory.messages[-1]
+        except IndexError:
+            return None
 
 
 def create_chat_chain(
     config: OllamaConfig,
 ) -> RunnableSerializable[dict[str, Any], BaseMessage]:
-    """Create the base chat chain for use by the code."""
+    """Create the chat chain for use by the code."""
     llm = ChatOllama(
         model=config.chat_model,
         base_url=config.base_url,
@@ -32,16 +40,28 @@ def create_chat_chain(
     prompt = ChatPromptTemplate.from_messages(
         [
             ("system", config.system_prompt),
+            MessagesPlaceholder("chat_history"),
             ("human", "Context\n{context}\n\nQuestion:\n{input}"),
         ]
     )
 
+    def get_chat_history(inputs: dict) -> list:
+        """Extract chat history from memory object."""
+        try:
+            return inputs["memory"].chat_memory.messages
+        except (KeyError, AttributeError):
+            return []
+
     def get_context(inputs: dict) -> str:
         """Extract the RAG context from the input object"""
-        return inputs.get("context", "")
+        try:
+            return inputs["context"]
+        except KeyError:
+            return ""
 
     return (
         RunnablePassthrough.assign(
+            chat_history=RunnableLambda(get_chat_history),
             context=RunnableLambda(get_context),
         )
         | prompt
@@ -49,28 +69,32 @@ def create_chat_chain(
     )
 
 
-def create_chat_client(config: OllamaConfig):
-    base_chain = create_chat_chain(config)
+def create_memory():
+    return ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+
 
+def create_chat_client(config: OllamaConfig):
     return ChatClient(
-        chain=base_chain,
+        chain=create_chat_chain(config),
+        memory=create_memory(),
     )
 
 
 def chat_with_client(
-    client: ChatClient,
-    message: str,
-    retriever: VectorStoreRetriever | None = None,
-    verbose: bool = False,
-) -> str:
-    """Chat with the client and return the response content."""
+    client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None
+) -> ChatClient:
+    """Chat with the client and return a client holding the updated memory."""
     if retriever:
         context = get_context_from_store(message, retriever)
     else:
         context = ""
 
     response = client.chain.invoke(
-        {"input": message, "context": context},
+        {"input": message, "memory": client.memory, "context": context}
     )
 
-    return response.content
+    memory = client.memory
+    memory.chat_memory.add_user_message(message)
+    memory.chat_memory.add_ai_message(response.content)
+
+    return ChatClient(chain=client.chain, memory=memory)
diff --git a/src/reviewllama/reviewllama.py b/src/reviewllama/reviewllama.py
index bab447d..ae0cbba 100644
--- a/src/reviewllama/reviewllama.py
+++ b/src/reviewllama/reviewllama.py
@@ -19,7 +19,7 @@ def run_reviewllama(config: ReviewConfig):
     retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama)
 
     for diff in analysis.diffs:
-        get_suggestions(diff, retriever, chat_client, config.verbose)
+        chat_client = get_suggestions(diff, retriever, chat_client)
 
 
 def create_and_log_chat_client(config: OllamaConfig) -> ChatClient:
@@ -44,16 +44,20 @@ def create_and_log_vector_store_retriever(
 
 
 def get_suggestions(
-    diff: GitDiff,
-    retriever: VectorStoreRetriever,
-    chat_client: ChatClient,
-    verbose: bool,
-):
-    response = chat_with_client(chat_client, craft_message(diff), retriever, verbose)
-    log_info(response)
+    diff: GitDiff, retriever: VectorStoreRetriever, chat_client: ChatClient
+) -> ChatClient:
+    new_client = chat_with_client(chat_client, craft_message(diff), retriever)
+    log_info(str(new_client.get_last_response_or_none().content))
+    return new_client
 
 
 def craft_message(diff) -> str:
     return (
+        "Review the following code changes and make up to three suggestions on "
+        "how to improve it. If the code is sufficiently simple or accurate then say "
+        "no suggestions can be found. Important issues you should consider are consistent "
+        "style, introduction of syntax errors, and potentially breaking changes in "
+        "interfaces/APIs that aren't properly handled.\n\n"
+        f"The original code:\n```\n{diff.old_content}\n```\n"
         f"The new code:\n```\n{diff.new_content}```"
     )
diff --git a/tests/test_llm.py b/tests/test_llm.py
index 77fe6f5..faa3a4e 100644
--- a/tests/test_llm.py
+++ b/tests/test_llm.py
@@ -17,10 +17,11 @@ def test_chat_client(ollama_config, chat_client):
     if not is_ollama_available(ollama_config):
         pytest.skip("Local Ollama server is not available")
 
-    response = chat_with_client(
+    chat_client = chat_with_client(
         chat_client, "Tell me your name and introduce yourself briefly"
     )
 
+    response = chat_client.get_last_response_or_none()
     assert response is not None
-    assert len(response) > 0
-    assert "gemma" in response.lower()
+    assert len(response.content) > 0
+    assert "gemma" in response.content.lower()