Remove message history
This commit is contained in:
parent
1c75cfc716
commit
c228a7298a
3 changed files with 8 additions and 45 deletions
|
@ -1,13 +1,10 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from langchain_community.chat_message_histories import ChatMessageHistory
|
||||
from langchain_core.chat_history import BaseChatMessageHistory
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_core.runnables import RunnableLambda
|
||||
from langchain_core.runnables.base import RunnableSerializable
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain_core.runnables.passthrough import RunnablePassthrough
|
||||
from langchain_core.vectorstores import VectorStoreRetriever
|
||||
from langchain_ollama import ChatOllama
|
||||
|
@ -19,15 +16,7 @@ from .configs import OllamaConfig
|
|||
|
||||
@dataclass(frozen=True)
|
||||
class ChatClient:
|
||||
chain: RunnableWithMessageHistory
|
||||
message_history: dict[str, BaseChatMessageHistory]
|
||||
|
||||
def get_last_response_or_none(self, session_id: str = "default"):
|
||||
try:
|
||||
messages = self.message_history[session_id].messages
|
||||
return messages[-1] if messages else None
|
||||
except (KeyError, IndexError):
|
||||
return None
|
||||
chain: RunnableSerializable
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
|
@ -43,7 +32,6 @@ def create_chat_chain(
|
|||
prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
("system", config.system_prompt),
|
||||
MessagesPlaceholder("history"),
|
||||
("human", "Context\n{context}\n\nQuestion:\n{input}"),
|
||||
]
|
||||
)
|
||||
|
@ -61,29 +49,11 @@ def create_chat_chain(
|
|||
)
|
||||
|
||||
|
||||
def get_session_history(
|
||||
session_id: str, message_history: dict[str, BaseChatMessageHistory]
|
||||
) -> BaseChatMessageHistory:
|
||||
"""Get or create chat message history for a session."""
|
||||
if session_id not in message_history:
|
||||
message_history[session_id] = ChatMessageHistory()
|
||||
return message_history[session_id]
|
||||
|
||||
|
||||
def create_chat_client(config: OllamaConfig):
|
||||
base_chain = create_chat_chain(config)
|
||||
message_history = {}
|
||||
|
||||
chain_with_history = RunnableWithMessageHistory(
|
||||
base_chain,
|
||||
lambda session_id: get_session_history(session_id, message_history),
|
||||
input_messages_key="input",
|
||||
history_messages_key="history",
|
||||
)
|
||||
|
||||
return ChatClient(
|
||||
chain=chain_with_history,
|
||||
message_history=message_history,
|
||||
chain=base_chain,
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue