diff --git a/src/reviewllama/cli.py b/src/reviewllama/cli.py index 43fa565..bdac075 100644 --- a/src/reviewllama/cli.py +++ b/src/reviewllama/cli.py @@ -69,27 +69,24 @@ Examples: default=( "You are a PR review assistant in charge of software quality control. " "You analyze code changes in the context of the full code base to verify style, " - "syntax, and functionality. Your response should contain exactly 3 suggestions. " - "Each suggestion must be in the following format:\n" - "```diff\n" - "-\n" - "+\n" - "```\n" - "Reason: \n\n" - "Here are two examples of the required format:\n" - "```diff\n" - "-somvr = 2 + 2\n" - "+somevar = 2 + 2\n" - "```\n" - "Reason: somvr is likely a typo, try replacing with somevar\n\n" - "```diff\n" - "-add_two_numbers(\"1\", \"2\")\n" - "+add_two_numbers(1,2)\n" - "```\n" - "Reason: add_two_numbers requires numeric values and does not accept strings" + "syntax, and functionality. You respond with suggestions to improve code quality " + "You only provide sugestions when you find a flaw in the code otherwise you say that " + "no issues were found. Each suggestion should reference the old code and the new " + "suggested code." + "Do not provide an analysis of the code and do not summarize suggestions. " + "Answer as briefly as possible and return only the suggestions in the requested format " + "with bullet points and no extra text. Provide examples when appropriate." ), help="Base branch to compare against (default: %(default)s)", ) + + parser.add_argument( + "-v", + "--verbose", + action="store_true", + default=False, + help="Enable verbose output, including messages sent to the LLM.", + ) return parser diff --git a/src/reviewllama/configs.py b/src/reviewllama/configs.py index d577915..bce1009 100644 --- a/src/reviewllama/configs.py +++ b/src/reviewllama/configs.py @@ -13,7 +13,7 @@ class OllamaConfig: base_url: str system_prompt: str # TODO: Update this to be a passed in value - temperature: float = field(default=0.7) + temperature: float = field(default=0.0) @dataclass(frozen=True) @@ -23,13 +23,14 @@ class ReviewConfig: paths: List[Path] ollama: OllamaConfig base_branch: str + verbose: bool def create_ollama_config( model: str, server_url: str, system_prompt: str, - temperature=0.7, + temperature=0.0, embedding_model="nomic-embed-text", ) -> OllamaConfig: """Create OllamaConfig with validated parameters.""" @@ -43,19 +44,24 @@ def create_ollama_config( def create_review_config( - paths: List[Path], ollama_config: OllamaConfig, base_branch: str + paths: List[Path], ollama_config: OllamaConfig, base_branch: str, verbose ) -> ReviewConfig: """Create complete ReviewConfig from validated components.""" - return ReviewConfig(paths=paths, ollama=ollama_config, base_branch=base_branch) + return ReviewConfig( + paths=paths, ollama=ollama_config, base_branch=base_branch, verbose=verbose + ) -def namespace_to_config( - namespace: argparse.Namespace -): +def namespace_to_config(namespace: argparse.Namespace): """Transform argparse namespace into ReviewConfig.""" paths = [Path(path_str) for path_str in namespace.paths] ollama_config = OllamaConfig( - chat_model=namespace.model, base_url=namespace.server_url, system_prompt=namespace.system_prompt, embedding_model=namespace.embedding_model + chat_model=namespace.model, + base_url=namespace.server_url, + system_prompt=namespace.system_prompt, + embedding_model=namespace.embedding_model, ) - return create_review_config(paths, ollama_config, namespace.base_branch) + return create_review_config( + paths, ollama_config, namespace.base_branch, namespace.verbose + ) diff --git a/src/reviewllama/llm.py b/src/reviewllama/llm.py index 73a7265..99e29c4 100644 --- a/src/reviewllama/llm.py +++ b/src/reviewllama/llm.py @@ -1,9 +1,8 @@ from dataclasses import dataclass from typing import Any -from langchain.memory import ConversationBufferMemory -from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder -from langchain.schema import BaseMessage +from langchain_core.messages import BaseMessage +from langchain_core.prompts import ChatPromptTemplate from langchain_core.runnables import RunnableLambda from langchain_core.runnables.base import RunnableSerializable from langchain_core.runnables.passthrough import RunnablePassthrough @@ -17,20 +16,13 @@ from .configs import OllamaConfig @dataclass(frozen=True) class ChatClient: - chain: RunnableSerializable[dict[str, Any], BaseMessage] - memory: ConversationBufferMemory - - def get_last_response_or_none(self): - try: - return self.memory.chat_memory.messages[-1] - except IndexError: - return None + chain: RunnableSerializable def create_chat_chain( config: OllamaConfig, ) -> RunnableSerializable[dict[str, Any], BaseMessage]: - """Create the chat chain for use by the code.""" + """Create the base chat chain for use by the code.""" llm = ChatOllama( model=config.chat_model, base_url=config.base_url, @@ -40,28 +32,16 @@ def create_chat_chain( prompt = ChatPromptTemplate.from_messages( [ ("system", config.system_prompt), - MessagesPlaceholder("chat_history"), ("human", "Context\n{context}\n\nQuestion:\n{input}"), ] ) - def get_chat_history(inputs: dict) -> list: - """Extract chat history from memory object.""" - try: - return inputs["memory"].chat_memory.messages - except AttributeError: - return [] - def get_context(inputs: dict) -> str: """Extract the RAG context from the input object""" - try: - return inputs["context"] - except AttributeError: - return "" + return inputs.get("context", "") return ( RunnablePassthrough.assign( - chat_history=RunnableLambda(get_chat_history), context=RunnableLambda(get_context), ) | prompt @@ -69,32 +49,28 @@ def create_chat_chain( ) -def create_memory(): - return ConversationBufferMemory(memory_key="chat_history", return_messages=True) - - def create_chat_client(config: OllamaConfig): + base_chain = create_chat_chain(config) + return ChatClient( - chain=create_chat_chain(config), - memory=create_memory(), + chain=base_chain, ) def chat_with_client( - client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None -) -> ChatClient: - + client: ChatClient, + message: str, + retriever: VectorStoreRetriever | None = None, + verbose: bool = False, +) -> str: + """Chat with the client and return the response content.""" if retriever: context = get_context_from_store(message, retriever) else: context = "" response = client.chain.invoke( - {"input": message, "memory": client.memory, "context": context} + {"input": message, "context": context}, ) - memory = client.memory - memory.chat_memory.add_user_message(message) - memory.chat_memory.add_ai_message(response.content) - - return ChatClient(chain=client.chain, memory=memory) + return response.content diff --git a/src/reviewllama/reviewllama.py b/src/reviewllama/reviewllama.py index ae0cbba..bab447d 100644 --- a/src/reviewllama/reviewllama.py +++ b/src/reviewllama/reviewllama.py @@ -19,7 +19,7 @@ def run_reviewllama(config: ReviewConfig): retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama) for diff in analysis.diffs: - chat_client = get_suggestions(diff, retriever, chat_client) + get_suggestions(diff, retriever, chat_client, config.verbose) def create_and_log_chat_client(config: OllamaConfig) -> ChatClient: @@ -44,20 +44,16 @@ def create_and_log_vector_store_retriever( def get_suggestions( - diff: GitDiff, retriever: VectorStoreRetriever, chat_client: ChatClient -) -> ChatClient: - new_client = chat_with_client(chat_client, craft_message(diff), retriever) - log_info(str(new_client.get_last_response_or_none().content)) - return new_client + diff: GitDiff, + retriever: VectorStoreRetriever, + chat_client: ChatClient, + verbose: bool, +): + response = chat_with_client(chat_client, craft_message(diff), retriever, verbose) + log_info(response) def craft_message(diff) -> str: return ( - "Review the following code changes and make up to three suggestions on " - "how to improve it. If the code is sufficiently simple or accurate then say " - "no suggestions can be found. Important issues you should consider are consistent " - "style, introduction of syntax errors, and potentially breaking changes in " - "interfaces/APIs that aren't properly handled.\n\n" - f"The original code:\n```\n{diff.old_content}\n```\n" f"The new code:\n```\n{diff.new_content}```" ) diff --git a/tests/test_llm.py b/tests/test_llm.py index faa3a4e..77fe6f5 100644 --- a/tests/test_llm.py +++ b/tests/test_llm.py @@ -17,11 +17,10 @@ def test_chat_client(ollama_config, chat_client): if not is_ollama_available(ollama_config): pytest.skip("Local Ollama server is not available") - chat_client = chat_with_client( + response = chat_with_client( chat_client, "Tell me your name and introduce yourself briefly" ) - response = chat_client.get_last_response_or_none() assert response is not None - assert len(response.content) > 0 - assert "gemma" in response.content.lower() + assert len(response) > 0 + assert "gemma" in response.lower()