Reformatting, fixing tests, adding basic RAG pipeline implementation

Alex Selimov 2025-07-05 15:16:18 -04:00
parent a6cdbf1761
commit 24bfef99a2
12 changed files with 721 additions and 131 deletions

src/reviewllama/llm.py Normal file

@@ -0,0 +1,101 @@
from dataclasses import dataclass
from typing import Any

from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.schema import BaseMessage
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_ollama import ChatOllama

from reviewllama.vector_store import get_context_from_store

from .configs import OllamaConfig


@dataclass(frozen=True)
class ChatClient:
    chain: RunnableSerializable[dict[str, Any], BaseMessage]
    memory: ConversationBufferMemory

    def get_last_response_or_none(self) -> BaseMessage | None:
        """Return the most recent message in memory, or None if memory is empty."""
        try:
            return self.memory.chat_memory.messages[-1]
        except IndexError:
            return None


def create_chat_chain(
    config: OllamaConfig,
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
    """Create the chat chain for use by the code."""
    llm = ChatOllama(
        model=config.chat_model,
        base_url=config.base_url,
        temperature=config.temperature,
    )

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", config.system_prompt),
            MessagesPlaceholder("chat_history"),
            ("human", "Context:\n{context}\n\nQuestion:\n{input}"),
        ]
    )

    def get_chat_history(inputs: dict) -> list:
        """Extract chat history from the memory object, defaulting to empty."""
        try:
            return inputs["memory"].chat_memory.messages
        except (KeyError, AttributeError):
            return []

    def get_context(inputs: dict) -> str:
        """Extract the RAG context from the input object, defaulting to empty."""
        try:
            return inputs["context"]
        except KeyError:
            return ""

    return (
        RunnablePassthrough.assign(
            chat_history=RunnableLambda(get_chat_history),
            context=RunnableLambda(get_context),
        )
        | prompt
        | llm
    )


def create_memory() -> ConversationBufferMemory:
    return ConversationBufferMemory(memory_key="chat_history", return_messages=True)


def create_chat_client(config: OllamaConfig) -> ChatClient:
    return ChatClient(
        chain=create_chat_chain(config),
        memory=create_memory(),
    )


def chat_with_client(
    client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None
) -> ChatClient:
    """Run one chat turn, optionally retrieving RAG context for the prompt,
    and return a ChatClient whose memory records the new user/AI exchange."""
    if retriever:
        context = get_context_from_store(message, retriever)
    else:
        context = ""

    response = client.chain.invoke(
        {"input": message, "memory": client.memory, "context": context}
    )

    memory = client.memory
    memory.chat_memory.add_user_message(message)
    memory.chat_memory.add_ai_message(response.content)

    return ChatClient(chain=client.chain, memory=memory)
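
For reference, a minimal usage sketch of the new module. The OllamaConfig constructor arguments are assumptions inferred from the attributes create_chat_chain reads (chat_model, base_url, temperature, system_prompt), and the concrete values are placeholders, not part of this commit:

# Usage sketch (not part of the commit). OllamaConfig field names are
# inferred from the attributes used above; values are illustrative only.
from reviewllama.configs import OllamaConfig
from reviewllama.llm import chat_with_client, create_chat_client

config = OllamaConfig(
    chat_model="llama3",                    # assumed field; any Ollama chat model tag
    base_url="http://localhost:11434",      # default local Ollama endpoint
    temperature=0.0,
    system_prompt="You are a helpful code reviewer.",
)

client = create_chat_client(config)

# Without a retriever the context slot stays empty; passing a
# VectorStoreRetriever enables RAG context injection via get_context_from_store.
client = chat_with_client(client, "Summarize the purpose of this module.")
print(client.get_last_response_or_none())

Because ChatClient is a frozen dataclass, each chat_with_client call returns a new client wrapping the same chain and the (mutated) memory, so callers thread the returned value through successive turns.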