from dataclasses import dataclass
from typing import Any

from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.base import RunnableSerializable
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables.passthrough import RunnablePassthrough
from langchain_core.vectorstores import VectorStoreRetriever
from langchain_ollama import ChatOllama

from reviewllama.vector_store import get_context_from_store

from .configs import OllamaConfig
|
@dataclass(frozen=True)
class ChatClient:
    """Immutable pairing of a history-aware chain with its session store."""

    # Runnable that injects per-session history into each invocation.
    chain: RunnableWithMessageHistory
    # Conversation logs keyed by session id; shared with the chain.
    message_history: dict[str, BaseChatMessageHistory]

    def get_last_response_or_none(self, session_id: str = "default"):
        """Return the newest message recorded for *session_id*.

        Yields None when the session id is unknown or the session has no
        messages yet.
        """
        history = self.message_history.get(session_id)
        if history is None or not history.messages:
            return None
        return history.messages[-1]
|
|
|
|
|
|
|
|
|
|
|
def create_chat_chain(
    config: OllamaConfig,
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
    """Assemble the base (history-free) chat pipeline: context -> prompt -> LLM."""
    model = ChatOllama(
        model=config.chat_model,
        base_url=config.base_url,
        temperature=config.temperature,
    )

    prompt_template = ChatPromptTemplate.from_messages(
        [
            ("system", config.system_prompt),
            MessagesPlaceholder("history"),
            ("human", "Context\n{context}\n\nQuestion:\n{input}"),
        ]
    )

    def extract_context(inputs: dict) -> str:
        """Pull the RAG context string out of the chain inputs (default: empty)."""
        return inputs.get("context", "")

    # Normalize the "context" key before formatting, then hand off to the LLM.
    with_context = RunnablePassthrough.assign(
        context=RunnableLambda(extract_context)
    )
    return with_context | prompt_template | model
|
|
|
|
|
|
|
|
|
2025-07-18 22:21:31 -04:00
|
|
|
def get_session_history(
    session_id: str, message_history: dict[str, BaseChatMessageHistory]
) -> BaseChatMessageHistory:
    """Get or create chat message history for a session."""
    try:
        return message_history[session_id]
    except KeyError:
        # First message for this session: start an empty history and cache it.
        history = ChatMessageHistory()
        message_history[session_id] = history
        return history
|
2025-07-05 15:16:18 -04:00
|
|
|
|
|
|
|
|
|
|
|
def create_chat_client(config: OllamaConfig) -> ChatClient:
    """Build a ChatClient whose chain records per-session conversation history.

    Args:
        config: Ollama model/connection settings used to build the base chain.

    Returns:
        A ChatClient bundling the history-aware chain with its session store.
    """
    base_chain = create_chat_chain(config)
    # Shared between the chain and the returned client so that
    # ChatClient.get_last_response_or_none sees what the chain records.
    message_history: dict[str, BaseChatMessageHistory] = {}

    chain_with_history = RunnableWithMessageHistory(
        base_chain,
        lambda session_id: get_session_history(session_id, message_history),
        input_messages_key="input",
        history_messages_key="history",
    )

    return ChatClient(
        chain=chain_with_history,
        message_history=message_history,
    )
|
|
|
|
|
|
|
|
|
|
|
|
def chat_with_client(
    client: ChatClient,
    message: str,
    retriever: VectorStoreRetriever | None = None,
    session_id: str = "default",
    verbose: bool = False,
) -> str:
    """Send *message* through the client's chain and return the reply text.

    When a retriever is supplied, relevant context is fetched from the vector
    store and injected into the prompt; otherwise the context is empty.
    NOTE: *verbose* is accepted but not used within this function.
    """
    context = get_context_from_store(message, retriever) if retriever else ""

    # session_id routes the call to the right message history.
    invoke_config = {"configurable": {"session_id": session_id}}
    response = client.chain.invoke(
        {"input": message, "context": context},
        config=invoke_config,
    )
    return response.content
|