Compare commits
3 commits
e59cf01ba9 ... f17b6721f9
Author | SHA1 | Date
---|---|---
 | f17b6721f9 |
 | c228a7298a |
 | 1c75cfc716 |
5 changed files with 57 additions and 83 deletions
@@ -69,27 +69,24 @@ Examples:
         default=(
             "You are a PR review assistant in charge of software quality control. "
             "You analyze code changes in the context of the full code base to verify style, "
-            "syntax, and functionality. Your response should contain exactly 3 suggestions. "
-            "Each suggestion must be in the following format:\n"
-            "```diff\n"
-            "-<old code>\n"
-            "+<new code>\n"
-            "```\n"
-            "Reason: <explanation>\n\n"
-            "Here are two examples of the required format:\n"
-            "```diff\n"
-            "-somvr = 2 + 2\n"
-            "+somevar = 2 + 2\n"
-            "```\n"
-            "Reason: somvr is likely a typo, try replacing with somevar\n\n"
-            "```diff\n"
-            "-add_two_numbers(\"1\", \"2\")\n"
-            "+add_two_numbers(1,2)\n"
-            "```\n"
-            "Reason: add_two_numbers requires numeric values and does not accept strings"
+            "syntax, and functionality. You respond with suggestions to improve code quality. "
+            "You only provide suggestions when you find a flaw in the code, otherwise you say that "
+            "no issues were found. Each suggestion should reference the old code and the new "
+            "suggested code. "
+            "Do not provide an analysis of the code and do not summarize suggestions. "
+            "Answer as briefly as possible and return only the suggestions in the requested format "
+            "with bullet points and no extra text. Provide examples when appropriate."
         ),
         help="Base branch to compare against (default: %(default)s)",
     )
 
+    parser.add_argument(
+        "-v",
+        "--verbose",
+        action="store_true",
+        default=False,
+        help="Enable verbose output, including messages sent to the LLM.",
+    )
     return parser
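The hunk above replaces the diff-formatted few-shot examples in the default system prompt with terser behavioral instructions, and registers a new `-v`/`--verbose` flag. A minimal, self-contained sketch of how such a `store_true` flag behaves, on a bare parser (the `prog` name here is assumed, and the real parser defines many other options not shown in this diff):

```python
import argparse

# Bare parser carrying only the new flag from this changeset.
parser = argparse.ArgumentParser(prog="reviewllama")
parser.add_argument(
    "-v",
    "--verbose",
    action="store_true",
    default=False,
    help="Enable verbose output, including messages sent to the LLM.",
)

assert parser.parse_args(["-v"]).verbose is True
assert parser.parse_args([]).verbose is False  # store_true falls back to the default
```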
@@ -13,7 +13,7 @@ class OllamaConfig:
     base_url: str
     system_prompt: str
     # TODO: Update this to be a passed in value
-    temperature: float = field(default=0.7)
+    temperature: float = field(default=0.0)
 
 
 @dataclass(frozen=True)
@@ -23,13 +23,14 @@ class ReviewConfig:
     paths: List[Path]
     ollama: OllamaConfig
     base_branch: str
+    verbose: bool
 
 
 def create_ollama_config(
     model: str,
     server_url: str,
     system_prompt: str,
-    temperature=0.7,
+    temperature=0.0,
     embedding_model="nomic-embed-text",
 ) -> OllamaConfig:
     """Create OllamaConfig with validated parameters."""
@@ -43,19 +44,24 @@ def create_ollama_config(
 
 
 def create_review_config(
-    paths: List[Path], ollama_config: OllamaConfig, base_branch: str
+    paths: List[Path], ollama_config: OllamaConfig, base_branch: str, verbose
 ) -> ReviewConfig:
     """Create complete ReviewConfig from validated components."""
-    return ReviewConfig(paths=paths, ollama=ollama_config, base_branch=base_branch)
+    return ReviewConfig(
+        paths=paths, ollama=ollama_config, base_branch=base_branch, verbose=verbose
+    )
 
 
-def namespace_to_config(
-    namespace: argparse.Namespace
-):
+def namespace_to_config(namespace: argparse.Namespace):
     """Transform argparse namespace into ReviewConfig."""
     paths = [Path(path_str) for path_str in namespace.paths]
     ollama_config = OllamaConfig(
-        chat_model=namespace.model, base_url=namespace.server_url, system_prompt=namespace.system_prompt, embedding_model=namespace.embedding_model
+        chat_model=namespace.model,
+        base_url=namespace.server_url,
+        system_prompt=namespace.system_prompt,
+        embedding_model=namespace.embedding_model,
    )
 
-    return create_review_config(paths, ollama_config, namespace.base_branch)
+    return create_review_config(
+        paths, ollama_config, namespace.base_branch, namespace.verbose
+    )
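With `verbose` added to `ReviewConfig` and threaded through `create_review_config`, and the sampling temperature pinned to 0.0 for reproducible reviews, construction looks roughly like this. This is a hypothetical wiring sketch using the functions from the hunk above; the model name, server URL, prompt, and path are placeholder values, not taken from the changeset:

```python
from pathlib import Path

# Placeholder values for illustration only.
ollama = OllamaConfig(
    chat_model="gemma3",                 # assumed model name
    base_url="http://localhost:11434",   # conventional local Ollama URL
    system_prompt="You are a PR review assistant...",
    embedding_model="nomic-embed-text",
)
config = create_review_config([Path("src/")], ollama, "main", verbose=True)

assert config.verbose is True
assert config.ollama.temperature == 0.0  # new default from this changeset
```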
@@ -1,9 +1,8 @@
 from dataclasses import dataclass
 from typing import Any
 
-from langchain.memory import ConversationBufferMemory
-from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain.schema import BaseMessage
+from langchain_core.messages import BaseMessage
+from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.runnables import RunnableLambda
 from langchain_core.runnables.base import RunnableSerializable
 from langchain_core.runnables.passthrough import RunnablePassthrough
@@ -17,20 +16,13 @@ from .configs import OllamaConfig
 
 @dataclass(frozen=True)
 class ChatClient:
-    chain: RunnableSerializable[dict[str, Any], BaseMessage]
-    memory: ConversationBufferMemory
-
-    def get_last_response_or_none(self):
-        try:
-            return self.memory.chat_memory.messages[-1]
-        except IndexError:
-            return None
+    chain: RunnableSerializable
 
 
 def create_chat_chain(
     config: OllamaConfig,
 ) -> RunnableSerializable[dict[str, Any], BaseMessage]:
-    """Create the chat chain for use by the code."""
+    """Create the base chat chain for use by the code."""
     llm = ChatOllama(
         model=config.chat_model,
         base_url=config.base_url,
@@ -40,28 +32,16 @@ def create_chat_chain(
     prompt = ChatPromptTemplate.from_messages(
         [
             ("system", config.system_prompt),
-            MessagesPlaceholder("chat_history"),
             ("human", "Context\n{context}\n\nQuestion:\n{input}"),
         ]
     )
 
-    def get_chat_history(inputs: dict) -> list:
-        """Extract chat history from memory object."""
-        try:
-            return inputs["memory"].chat_memory.messages
-        except AttributeError:
-            return []
-
     def get_context(inputs: dict) -> str:
         """Extract the RAG context from the input object"""
-        try:
-            return inputs["context"]
-        except AttributeError:
-            return ""
+        return inputs.get("context", "")
 
     return (
         RunnablePassthrough.assign(
-            chat_history=RunnableLambda(get_chat_history),
             context=RunnableLambda(get_context),
         )
         | prompt
@@ -69,32 +49,28 @@ def create_chat_chain(
     )
 
 
-def create_memory():
-    return ConversationBufferMemory(memory_key="chat_history", return_messages=True)
-
-
 def create_chat_client(config: OllamaConfig):
+    base_chain = create_chat_chain(config)
+
     return ChatClient(
-        chain=create_chat_chain(config),
-        memory=create_memory(),
+        chain=base_chain,
     )
 
 
 def chat_with_client(
-    client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None
-) -> ChatClient:
+    client: ChatClient,
+    message: str,
+    retriever: VectorStoreRetriever | None = None,
+    verbose: bool = False,
+) -> str:
+    """Chat with the client and return the response content."""
     if retriever:
         context = get_context_from_store(message, retriever)
     else:
         context = ""
 
     response = client.chain.invoke(
-        {"input": message, "memory": client.memory, "context": context}
+        {"input": message, "context": context},
     )
 
-    memory = client.memory
-    memory.chat_memory.add_user_message(message)
-    memory.chat_memory.add_ai_message(response.content)
-
-    return ChatClient(chain=client.chain, memory=memory)
+    return response.content
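This refactor removes `ConversationBufferMemory` entirely: each `chat_with_client` call is now a stateless single-turn invocation that returns the response text instead of a new client. A self-contained sketch of the same `RunnablePassthrough.assign` pattern, with a stand-in lambda where the real chain uses `ChatOllama`, so the sketch runs without an Ollama server (the prompt text and input are placeholders):

```python
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.passthrough import RunnablePassthrough

# Stand-in for ChatOllama: echoes the fully formatted prompt as a string.
fake_llm = RunnableLambda(lambda prompt: prompt.to_string())

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a terse reviewer."),  # placeholder system prompt
        ("human", "Context\n{context}\n\nQuestion:\n{input}"),
    ]
)

chain = (
    RunnablePassthrough.assign(
        # Same trick as get_context above: a missing context becomes "".
        context=RunnableLambda(lambda inputs: inputs.get("context", ""))
    )
    | prompt
    | fake_llm
)

print(chain.invoke({"input": "Is this diff OK?"}))
```

Because the chain no longer carries history, repeated invocations are independent, which matches the one-shot-per-diff usage in the reviewer below.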
@@ -19,7 +19,7 @@ def run_reviewllama(config: ReviewConfig):
     retriever = create_and_log_vector_store_retriever(analysis.repo, config.ollama)
 
     for diff in analysis.diffs:
-        chat_client = get_suggestions(diff, retriever, chat_client)
+        get_suggestions(diff, retriever, chat_client, config.verbose)
 
 
 def create_and_log_chat_client(config: OllamaConfig) -> ChatClient:
@@ -44,20 +44,16 @@ def create_and_log_vector_store_retriever(
 
 
 def get_suggestions(
-    diff: GitDiff, retriever: VectorStoreRetriever, chat_client: ChatClient
-) -> ChatClient:
-    new_client = chat_with_client(chat_client, craft_message(diff), retriever)
-    log_info(str(new_client.get_last_response_or_none().content))
-    return new_client
+    diff: GitDiff,
+    retriever: VectorStoreRetriever,
+    chat_client: ChatClient,
+    verbose: bool,
+):
+    response = chat_with_client(chat_client, craft_message(diff), retriever, verbose)
+    log_info(response)
 
 
 def craft_message(diff) -> str:
     return (
-        "Review the following code changes and make up to three suggestions on "
-        "how to improve it. If the code is sufficiently simple or accurate then say "
-        "no suggestions can be found. Important issues you should consider are consistent "
-        "style, introduction of syntax errors, and potentially breaking changes in "
-        "interfaces/APIs that aren't properly handled.\n\n"
-        f"The original code:\n```\n{diff.old_content}\n```\n"
         f"The new code:\n```\n{diff.new_content}```"
     )
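`get_suggestions` now logs the response string instead of threading a new client back to the caller, and `craft_message` carries only the code under review, since the review instructions moved into the default system prompt in the first file. A toy illustration of the message it builds, with a hypothetical stand-in for the repo's `GitDiff`:

```python
from dataclasses import dataclass

# Hypothetical stand-in for the repo's GitDiff (defined elsewhere).
@dataclass
class FakeDiff:
    old_content: str
    new_content: str

def craft_message(diff) -> str:
    # Mirrors the post-change craft_message above.
    return f"The new code:\n```\n{diff.new_content}```"

print(craft_message(FakeDiff(old_content="x = 1\n", new_content="x = 2\n")))
# The new code:
# ```
# x = 2
# ```
```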
@@ -17,11 +17,10 @@ def test_chat_client(ollama_config, chat_client):
     if not is_ollama_available(ollama_config):
         pytest.skip("Local Ollama server is not available")
 
-    chat_client = chat_with_client(
+    response = chat_with_client(
         chat_client, "Tell me your name and introduce yourself briefly"
     )
-    response = chat_client.get_last_response_or_none()
 
     assert response is not None
-    assert len(response.content) > 0
-    assert "gemma" in response.content.lower()
+    assert len(response) > 0
+    assert "gemma" in response.lower()