78 lines
2.1 KiB
Python
78 lines
2.1 KiB
Python
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from langchain_core.messages import BaseMessage
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
|
from langchain_core.runnables import RunnableLambda
|
|
from langchain_core.runnables.base import RunnableSerializable
|
|
from langchain_core.runnables.passthrough import RunnablePassthrough
|
|
from langchain_core.vectorstores import VectorStoreRetriever
|
|
from langchain_ollama import ChatOllama
|
|
|
|
from reviewllama.vector_store import get_context_from_store
|
|
|
|
from .configs import OllamaConfig
|
|
|
|
|
|
@dataclass(frozen=True)
class ChatClient:
    """Immutable holder for the runnable chat chain.

    ``create_chat_client`` builds instances of this class, and
    ``chat_with_client`` calls ``chain.invoke`` on them.
    """

    # The composed prompt -> model runnable invoked for each chat turn.
    chain: RunnableSerializable
|
|
|
|
|
|
def create_chat_chain(
    config: OllamaConfig,
) -> RunnableSerializable[dict[str, Any], BaseMessage]:
    """Create the base chat chain for use by the code.

    The chain takes a dict with an ``"input"`` key (the question) and an
    optional ``"context"`` key (retrieved RAG context). It renders a
    system + human prompt and sends it to an Ollama chat model.

    Args:
        config: Ollama settings. This function reads ``chat_model``,
            ``base_url``, ``temperature`` and ``system_prompt``.

    Returns:
        A runnable mapping the input dict to the model's ``BaseMessage`` reply.
    """
    llm = ChatOllama(
        model=config.chat_model,
        base_url=config.base_url,
        temperature=config.temperature,
    )

    # ChatPromptTemplate treats message strings as f-string templates, so a
    # literal "{" or "}" in the configured system prompt would be parsed as a
    # template variable and break the chain when it is built or invoked.
    # Escape the braces so the system prompt is sent to the model verbatim.
    system_prompt = config.system_prompt.replace("{", "{{").replace("}", "}}")

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            ("human", "Context\n{context}\n\nQuestion:\n{input}"),
        ]
    )

    def get_context(inputs: dict) -> str:
        """Extract the RAG context from the input object.

        A missing or ``None`` context becomes "", so the prompt's
        ``{context}`` variable is always filled with something printable.
        """
        context = inputs.get("context")
        return "" if context is None else context

    # Re-assigning "context" lets callers omit it: the prompt requires the
    # variable, and RunnablePassthrough.assign leaves "input" untouched.
    return (
        RunnablePassthrough.assign(context=RunnableLambda(get_context))
        | prompt
        | llm
    )
|
|
|
|
|
|
def create_chat_client(config: OllamaConfig) -> ChatClient:
    """Build a ``ChatClient`` around a freshly created chat chain.

    Args:
        config: Ollama settings, passed unchanged to ``create_chat_chain``.

    Returns:
        A frozen ``ChatClient`` whose ``chain`` is the newly built chain.
    """
    return ChatClient(chain=create_chat_chain(config))
|
|
|
|
|
|
def chat_with_client(
    client: ChatClient,
    message: str,
    retriever: VectorStoreRetriever | None = None,
    session_id: str = "default",
    verbose: bool = False,
) -> str:
    """Send ``message`` through the client's chain and return the reply content.

    Args:
        client: Holder of the runnable chain to invoke.
        message: The user's question. It is passed to the chain as
            ``"input"`` and, when a retriever is given, also used as the
            retrieval query.
        retriever: Optional vector-store retriever. When truthy, context is
            fetched with ``get_context_from_store``; otherwise the context
            is "".
        session_id: Forwarded to the chain as
            ``config["configurable"]["session_id"]``.
        verbose: Currently unused by this function.

    Returns:
        The ``content`` attribute of the chain's response.
    """
    context = get_context_from_store(message, retriever) if retriever else ""

    payload = {"input": message, "context": context}
    run_config = {"configurable": {"session_id": session_id}}
    reply = client.chain.invoke(payload, config=run_config)
    return reply.content
|