Reformatting, fixing tests, adding basic RAG pipeline implementation
This commit is contained in:
parent
a6cdbf1761
commit
24bfef99a2
12 changed files with 721 additions and 131 deletions
101
src/reviewllama/llm.py
Normal file
101
src/reviewllama/llm.py
Normal file
|
@ -0,0 +1,101 @@
|
|||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain.memory import ConversationBufferMemory
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema import BaseMessage
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langchain_core.runnables.base import RunnableSerializable
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
from langchain_ollama import ChatOllama
|
||||
|
||||
from reviewllama.vector_store import get_context_from_store
|
||||
|
||||
from .configs import OllamaConfig
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ChatClient:
|
||||
chain: RunnableSerializable[dict[str, Any], BaseMessage]
|
||||
memory: ConversationBufferMemory
|
||||
|
||||
def get_last_response_or_none(self):
|
||||
try:
|
||||
return self.memory.chat_memory.messages[-1]
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
config: OllamaConfig,
|
||||
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
|
||||
"""Create the chat chain for use by the code."""
|
||||
llm = ChatOllama(
|
||||
model=config.chat_model,
|
||||
base_url=config.base_url,
|
||||
temperature=config.temperature,
|
||||
)
|
||||
|
||||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", config.system_prompt),
|
||||
MessagesPlaceholder("chat_history"),
|
||||
("human", "Context\n{context}\n\nQuestion:\n{input}"),
|
||||
]
|
||||
)
|
||||
|
||||
def get_chat_history(inputs: dict) -> list:
|
||||
"""Extract chat history from memory object."""
|
||||
try:
|
||||
return inputs["memory"].chat_memory.messages
|
||||
except AttributeError:
|
||||
return []
|
||||
|
||||
def get_context(inputs: dict) -> str:
|
||||
"""Extract the RAG context from the input object"""
|
||||
try:
|
||||
return inputs["context"]
|
||||
except AttributeError:
|
||||
return ""
|
||||
|
||||
return (
|
||||
RunnablePassthrough.assign(
|
||||
chat_history=RunnableLambda(get_chat_history),
|
||||
context=RunnableLambda(get_context),
|
||||
)
|
||||
| prompt
|
||||
| llm
|
||||
)
|
||||
|
||||
|
||||
def create_memory():
|
||||
return ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
||||
|
||||
|
||||
def create_chat_client(config: OllamaConfig):
|
||||
return ChatClient(
|
||||
chain=create_chat_chain(config),
|
||||
memory=create_memory(),
|
||||
)
|
||||
|
||||
|
||||
def chat_with_client(
|
||||
client: ChatClient, message: str, retriever: VectorStoreRetriever | None = None
|
||||
) -> ChatClient:
|
||||
|
||||
if retriever:
|
||||
context = get_context_from_store(message, retriever)
|
||||
else:
|
||||
context = ""
|
||||
|
||||
response = client.chain.invoke(
|
||||
{"input": message, "memory": client.memory, "context": context}
|
||||
)
|
||||
|
||||
memory = client.memory
|
||||
memory.chat_memory.add_user_message(message)
|
||||
memory.chat_memory.add_ai_message(response.content)
|
||||
|
||||
return ChatClient(chain=client.chain, memory=memory)
|
Loading…
Add table
Add a link
Reference in a new issue