Clean up tests and add base_url to embedding model
This commit is contained in:
parent
cb49495211
commit
d917a9c067
6 changed files with 21 additions and 87 deletions
|
@ -1,5 +1,4 @@
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain.memory import ConversationBufferMemory
|
from langchain.memory import ConversationBufferMemory
|
||||||
|
|
|
@ -39,7 +39,7 @@ def log_paths(paths: List[Path]) -> None:
|
||||||
console.print()
|
console.print()
|
||||||
|
|
||||||
def log_info(info: str) -> None:
|
def log_info(info: str) -> None:
|
||||||
"""Log message with colored output."""
|
"""Log info message with colored output."""
|
||||||
console = create_console()
|
console = create_console()
|
||||||
console.print(f"{info}")
|
console.print(f"{info}")
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ def run_reviewllama(config: ReviewConfig):
|
||||||
chat_client = create_and_log_chat_client(config.ollama)
|
chat_client = create_and_log_chat_client(config.ollama)
|
||||||
analysis = create_and_log_git_diff_analysis(path, config.base_branch)
|
analysis = create_and_log_git_diff_analysis(path, config.base_branch)
|
||||||
retriever = create_and_log_vector_store_retriever(
|
retriever = create_and_log_vector_store_retriever(
|
||||||
analysis.repo, config.ollama.embedding_model
|
analysis.repo, config.ollama
|
||||||
)
|
)
|
||||||
|
|
||||||
for diff in analysis.diffs:
|
for diff in analysis.diffs:
|
||||||
|
@ -37,12 +37,12 @@ def create_and_log_git_diff_analysis(path: Path, base_branch: str) -> GitAnalysi
|
||||||
|
|
||||||
|
|
||||||
def create_and_log_vector_store_retriever(
|
def create_and_log_vector_store_retriever(
|
||||||
repo: Repo, embedding_model: str
|
repo: Repo, config: OllamaConfig
|
||||||
) -> VectorStoreRetriever:
|
) -> VectorStoreRetriever:
|
||||||
log_info("Creating vector_store...")
|
log_info("Creating vector_store...")
|
||||||
retriever = create_retriever(
|
retriever = create_retriever(
|
||||||
get_tracked_files(repo),
|
get_tracked_files(repo),
|
||||||
embedding_model,
|
config
|
||||||
)
|
)
|
||||||
log_info("Done creating vector store")
|
log_info("Done creating vector store")
|
||||||
return retriever
|
return retriever
|
||||||
|
|
|
@ -6,21 +6,25 @@ from langchain_core.documents.base import Document
|
||||||
from langchain_core.vectorstores import VectorStoreRetriever
|
from langchain_core.vectorstores import VectorStoreRetriever
|
||||||
from langchain_ollama.embeddings import OllamaEmbeddings
|
from langchain_ollama.embeddings import OllamaEmbeddings
|
||||||
|
|
||||||
|
from .configs import OllamaConfig
|
||||||
|
|
||||||
|
|
||||||
def documents_from_path_list(file_paths: list[Path | str]) -> list[Document]:
|
def documents_from_path_list(file_paths: list[Path | str]) -> list[Document]:
|
||||||
return [doc for file_path in file_paths for doc in TextLoader(file_path).load()]
|
return [doc for file_path in file_paths for doc in TextLoader(file_path).load()]
|
||||||
|
|
||||||
|
|
||||||
def create_retriever(
|
def create_retriever(
|
||||||
file_paths: list[Path | str], embedding_model: str
|
file_paths: list[Path | str], config: OllamaConfig
|
||||||
) -> VectorStoreRetriever:
|
) -> VectorStoreRetriever:
|
||||||
embeddings = OllamaEmbeddings(model=embedding_model)
|
embeddings = OllamaEmbeddings(
|
||||||
|
model=config.embedding_model, base_url=config.base_url
|
||||||
|
)
|
||||||
vectorstore = FAISS.from_documents(documents_from_path_list(file_paths), embeddings)
|
vectorstore = FAISS.from_documents(documents_from_path_list(file_paths), embeddings)
|
||||||
return vectorstore.as_retriever()
|
return vectorstore.as_retriever()
|
||||||
|
|
||||||
|
|
||||||
def get_context_from_store(message: str, retriever: VectorStoreRetriever):
|
def get_context_from_store(message: str, retriever: VectorStoreRetriever):
|
||||||
docs = retriever.get_relevant_documents(message)
|
docs = retriever.invoke(message)
|
||||||
return "\n\n".join([doc.page_content for doc in docs])
|
return "\n\n".join([doc.page_content for doc in docs])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,5 +6,5 @@ from reviewllama.configs import create_ollama_config
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def ollama_config():
|
def ollama_config():
|
||||||
return create_ollama_config(
|
return create_ollama_config(
|
||||||
"gemma3:4b", "localhost:11434", "You are a helpful assistant.", 0.0
|
"gemma3:4b", "localhost:11434", "You are a helpful assistant.", 0.0, "nomic-embed-text"
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,16 +1,13 @@
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import MagicMock, Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from langchain_core.documents.base import Document
|
from langchain_core.documents.base import Document
|
||||||
from langchain_core.vectorstores import VectorStoreRetriever
|
from langchain_core.vectorstores import VectorStoreRetriever
|
||||||
|
|
||||||
from reviewllama.utilities import is_ollama_available
|
from reviewllama.vector_store import create_retriever, documents_from_path_list
|
||||||
from reviewllama.vector_store import (create_retriever,
|
|
||||||
documents_from_path_list,
|
|
||||||
get_context_from_store)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -52,7 +49,9 @@ def test_load_documents(temp_files):
|
||||||
@patch("reviewllama.vector_store.OllamaEmbeddings")
|
@patch("reviewllama.vector_store.OllamaEmbeddings")
|
||||||
@patch("reviewllama.vector_store.FAISS")
|
@patch("reviewllama.vector_store.FAISS")
|
||||||
@patch("reviewllama.vector_store.documents_from_path_list")
|
@patch("reviewllama.vector_store.documents_from_path_list")
|
||||||
def test_create_retriever(mock_docs_from_list, mock_faiss, mock_embeddings):
|
def test_create_retriever(
|
||||||
|
mock_docs_from_list, mock_faiss, mock_embeddings, ollama_config
|
||||||
|
):
|
||||||
"""Test successful retriever creation"""
|
"""Test successful retriever creation"""
|
||||||
# Setup mocks
|
# Setup mocks
|
||||||
mock_docs = [Document(page_content="test", metadata={"source": "test.txt"})]
|
mock_docs = [Document(page_content="test", metadata={"source": "test.txt"})]
|
||||||
|
@ -67,83 +66,15 @@ def test_create_retriever(mock_docs_from_list, mock_faiss, mock_embeddings):
|
||||||
mock_faiss.from_documents.return_value = mock_vectorstore
|
mock_faiss.from_documents.return_value = mock_vectorstore
|
||||||
|
|
||||||
# Test
|
# Test
|
||||||
result = create_retriever(["test.txt"], "test-embedding-model")
|
result = create_retriever(["test.txt"], ollama_config)
|
||||||
|
|
||||||
# Assertions
|
# Assertions
|
||||||
assert result == mock_retriever
|
assert result == mock_retriever
|
||||||
mock_embeddings.assert_called_once_with(model="test-embedding-model")
|
mock_embeddings.assert_called_once_with(
|
||||||
|
model=ollama_config.embedding_model, base_url=ollama_config.base_url
|
||||||
|
)
|
||||||
mock_docs_from_list.assert_called_once_with(["test.txt"])
|
mock_docs_from_list.assert_called_once_with(["test.txt"])
|
||||||
mock_faiss.from_documents.assert_called_once_with(
|
mock_faiss.from_documents.assert_called_once_with(
|
||||||
mock_docs, mock_embedding_instance
|
mock_docs, mock_embedding_instance
|
||||||
)
|
)
|
||||||
mock_vectorstore.as_retriever.assert_called_once()
|
mock_vectorstore.as_retriever.assert_called_once()
|
||||||
|
|
||||||
def test_get_context_from_store_success():
|
|
||||||
"""Test successful context retrieval"""
|
|
||||||
mock_retriever = Mock(spec=VectorStoreRetriever)
|
|
||||||
mock_docs = [
|
|
||||||
Document(page_content="First relevant document", metadata={}),
|
|
||||||
Document(page_content="Second relevant document", metadata={}),
|
|
||||||
Document(page_content="Third relevant document", metadata={}),
|
|
||||||
]
|
|
||||||
mock_retriever.get_relevant_documents.return_value = mock_docs
|
|
||||||
|
|
||||||
result = get_context_from_store("test query", mock_retriever)
|
|
||||||
|
|
||||||
expected = "First relevant document\n\nSecond relevant document\n\nThird relevant document"
|
|
||||||
assert result == expected
|
|
||||||
mock_retriever.get_relevant_documents.assert_called_once_with("test query")
|
|
||||||
|
|
||||||
@patch('reviewllama.vector_store.OllamaEmbeddings')
|
|
||||||
@patch('reviewllama.vector_store.FAISS')
|
|
||||||
def test_full_pipeline_mock(mock_faiss, mock_embeddings, temp_files):
|
|
||||||
"""Test the full pipeline with mocked external dependencies"""
|
|
||||||
# Setup mocks
|
|
||||||
mock_embedding_instance = Mock()
|
|
||||||
mock_embeddings.return_value = mock_embedding_instance
|
|
||||||
|
|
||||||
mock_vectorstore = Mock()
|
|
||||||
mock_retriever = Mock(spec=VectorStoreRetriever)
|
|
||||||
mock_retriever.get_relevant_documents.return_value = [
|
|
||||||
Document(page_content="Relevant test content", metadata={})
|
|
||||||
]
|
|
||||||
mock_vectorstore.as_retriever.return_value = mock_retriever
|
|
||||||
mock_faiss.from_documents.return_value = mock_vectorstore
|
|
||||||
|
|
||||||
# Test full pipeline
|
|
||||||
retriever = create_retriever(temp_files[:2], "test-model")
|
|
||||||
context = get_context_from_store("test query", retriever)
|
|
||||||
|
|
||||||
assert context == "Relevant test content"
|
|
||||||
mock_embeddings.assert_called_once_with(model="test-model")
|
|
||||||
mock_retriever.get_relevant_documents.assert_called_once_with("test query")
|
|
||||||
|
|
||||||
def test_documents_from_list_content_verification(temp_files):
|
|
||||||
"""Test that documents contain expected content"""
|
|
||||||
docs = documents_from_path_list(temp_files)
|
|
||||||
|
|
||||||
contents = [doc.page_content for doc in docs]
|
|
||||||
|
|
||||||
# Check that we have the expected content
|
|
||||||
assert any("Python code examples" in content for content in contents)
|
|
||||||
assert any("JavaScript functions" in content for content in contents)
|
|
||||||
assert any("testing best practices" in content for content in contents)
|
|
||||||
assert any("deployment info" in content for content in contents)
|
|
||||||
|
|
||||||
# Optional: Integration test that requires actual Ollama server
|
|
||||||
def test_create_retriever_with_real_ollama(temp_files, ollama_config):
|
|
||||||
"""Integration test with real Ollama (requires server running)"""
|
|
||||||
if not is_ollama_available(ollama_config):
|
|
||||||
pytest.skip("Local Ollama server is not available")
|
|
||||||
try:
|
|
||||||
# This test would use a real embedding model
|
|
||||||
# Skip by default unless explicitly testing integration
|
|
||||||
retriever = create_retriever(temp_files[:2], "nomic-embed-text")
|
|
||||||
assert retriever is not None
|
|
||||||
|
|
||||||
# Test actual retrieval
|
|
||||||
context = get_context_from_store("Python code", retriever)
|
|
||||||
assert isinstance(context, str)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
pytest.skip(f"Ollama server not available or model not found: {e}")
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue