|
4 | 4 | import numpy as np |
5 | 5 | import pytest |
6 | 6 | import torch |
7 | | -from pgvector.psycopg import Bit |
8 | 7 |
|
9 | 8 | from core.models.chunk import DocumentChunk |
10 | 9 | from core.tests import setup_test_logging |
11 | | -from core.vector_store.multi_vector_store import MultiVectorStore |
| 10 | +from core.vector_store.milvus_multivector_store import MilvusMultiVectorStore |
12 | 11 |
|
13 | 12 | # Set up test logging |
14 | 13 | setup_test_logging() |
@@ -50,102 +49,115 @@ def get_sample_document_chunks(num_chunks=3, num_vectors=3, dim=128): |
50 | 49 | return chunks |
51 | 50 |
|
52 | 51 |
|
53 | | -# Fixtures |
| 52 | +# For Milvus |
54 | 53 | @pytest.fixture(scope="function") |
55 | 54 | async def vector_store(): |
56 | 55 | """Create a real MultiVectorStore instance connected to the test database""" |
57 | | - # Create the store |
58 | | - store = MultiVectorStore(uri=TEST_DB_URI) |
59 | | - |
60 | | - try: |
61 | | - # Try to initialize the database |
62 | | - store.initialize() |
63 | | - |
64 | | - # Clean up any existing data |
65 | | - store.conn.execute("TRUNCATE TABLE multi_vector_embeddings RESTART IDENTITY") |
66 | | - |
67 | | - # Drop the function if it exists |
68 | | - try: |
69 | | - store.conn.execute("DROP FUNCTION IF EXISTS max_sim(bit[], bit[])") |
70 | | - except Exception as e: |
71 | | - print(f"Error dropping function: {e}") |
72 | | - except Exception as e: |
73 | | - print(f"Error setting up database: {e}") |
74 | | - |
| 56 | + store = MilvusMultiVectorStore(collection_name="test_collection") |
| 57 | + store.client.drop_collection(collection_name="test_collection") |
| 58 | + store._create_collection() |
| 59 | + store.client.load_collection(collection_name="test_collection") |
75 | 60 | yield store |
76 | | - |
77 | | - # Clean up after tests |
78 | | - try: |
79 | | - store.conn.execute("TRUNCATE TABLE multi_vector_embeddings RESTART IDENTITY") |
80 | | - except Exception as e: |
81 | | - print(f"Error cleaning up: {e}") |
82 | | - |
83 | | - # Close connection |
84 | 61 | store.close() |
85 | 62 |
|
86 | 63 |
|
87 | | -# Glassbox Tests - Testing internal implementation details |
88 | | -@pytest.mark.asyncio |
89 | | -async def test_binary_quantize(): |
90 | | - """Test the _binary_quantize method correctly converts embeddings""" |
91 | | - store = MultiVectorStore(uri=TEST_DB_URI) |
92 | | - |
93 | | - # Test with torch tensor |
94 | | - torch_embeddings = torch.tensor([[0.1, -0.2, 0.3], [-0.1, 0.2, -0.3]]) |
95 | | - binary_result = store._binary_quantize(torch_embeddings) |
96 | | - assert len(binary_result) == 2 |
97 | | - |
98 | | - # Check results match expected patterns |
99 | | - assert binary_result[0].to_text() == Bit("101").to_text() # Positive values (>0) become 1, negative/zero become 0 |
100 | | - assert binary_result[1].to_text() == Bit("010").to_text() # First row: [0.1 (>0), -0.2 (<0), 0.3 (>0)] → "101" |
101 | | - # Second row: [-0.1 (<0), 0.2 (>0), -0.3 (<0)] → "010" |
102 | | - |
103 | | - # Test with numpy array |
104 | | - numpy_embeddings = np.array([[0.1, -0.2, 0.3], [-0.1, 0.2, -0.3]]) |
105 | | - binary_result = store._binary_quantize(numpy_embeddings) |
106 | | - assert len(binary_result) == 2 |
107 | | - |
108 | | - assert binary_result[0].to_text() == Bit("101").to_text() |
109 | | - assert binary_result[1].to_text() == Bit("010").to_text() |
110 | | - |
111 | | - |
112 | | -@pytest.mark.asyncio |
113 | | -async def test_initialize_creates_tables_and_function(vector_store): |
114 | | - """Test that initialize creates the necessary tables and functions""" |
115 | | - vector_store.initialize() |
116 | | - # Check if the table exists |
117 | | - result = vector_store.conn.execute( |
118 | | - "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'multi_vector_embeddings')" |
119 | | - ).fetchone() |
120 | | - table_exists = result[0] |
121 | | - assert table_exists is True |
122 | | - |
123 | | - logger.info("Table exists!") |
124 | | - |
125 | | - # Check if the max_sim function exists |
126 | | - result = vector_store.conn.execute("SELECT EXISTS (SELECT FROM pg_proc WHERE proname = 'max_sim')").fetchone() |
127 | | - function_exists = result[0] |
128 | | - logger.info(f"Function exists {function_exists}") |
129 | | - assert function_exists is True |
130 | | - |
131 | | - |
132 | | -@pytest.mark.asyncio |
133 | | -async def test_database_schema(vector_store): |
134 | | - """Test that the database schema matches our expectations""" |
135 | | - # Check columns in the table |
136 | | - result = vector_store.conn.execute( |
137 | | - "SELECT column_name, data_type FROM information_schema.columns " "WHERE table_name = 'multi_vector_embeddings'" |
138 | | - ).fetchall() |
139 | | - |
140 | | - # Convert to a dict for easier checking |
141 | | - column_dict = {col[0]: col[1] for col in result} |
142 | | - |
143 | | - # Check required columns |
144 | | - assert "id" in column_dict |
145 | | - assert "document_id" in column_dict |
146 | | - assert "chunk_number" in column_dict |
147 | | - assert "content" in column_dict |
148 | | - assert "embeddings" in column_dict |
| 64 | +# For Postgres |
| 65 | +# # Fixtures |
| 66 | +# @pytest.fixture(scope="function") |
| 67 | +# async def vector_store(): |
| 68 | +# """Create a real MultiVectorStore instance connected to the test database""" |
| 69 | +# # Create the store |
| 70 | +# store = MultiVectorStore(uri=TEST_DB_URI) |
| 71 | + |
| 72 | +# try: |
| 73 | +# # Try to initialize the database |
| 74 | +# store.initialize() |
| 75 | + |
| 76 | +# # Clean up any existing data |
| 77 | +# store.conn.execute("TRUNCATE TABLE multi_vector_embeddings RESTART IDENTITY") |
| 78 | + |
| 79 | +# # Drop the function if it exists |
| 80 | +# try: |
| 81 | +# store.conn.execute("DROP FUNCTION IF EXISTS max_sim(bit[], bit[])") |
| 82 | +# except Exception as e: |
| 83 | +# print(f"Error dropping function: {e}") |
| 84 | +# except Exception as e: |
| 85 | +# print(f"Error setting up database: {e}") |
| 86 | + |
| 87 | +# yield store |
| 88 | + |
| 89 | +# # Clean up after tests |
| 90 | +# try: |
| 91 | +# store.conn.execute("TRUNCATE TABLE multi_vector_embeddings RESTART IDENTITY") |
| 92 | +# except Exception as e: |
| 93 | +# print(f"Error cleaning up: {e}") |
| 94 | + |
| 95 | +# # Close connection |
| 96 | +# store.close() |
| 97 | + |
| 98 | + |
| 99 | +# # Glassbox Tests - Testing internal implementation details |
| 100 | +# @pytest.mark.asyncio |
| 101 | +# async def test_binary_quantize(): |
| 102 | +# """Test the _binary_quantize method correctly converts embeddings""" |
| 103 | +# store = MultiVectorStore(uri=TEST_DB_URI) |
| 104 | + |
| 105 | +# # Test with torch tensor |
| 106 | +# torch_embeddings = torch.tensor([[0.1, -0.2, 0.3], [-0.1, 0.2, -0.3]]) |
| 107 | +# binary_result = store._binary_quantize(torch_embeddings) |
| 108 | +# assert len(binary_result) == 2 |
| 109 | + |
| 110 | +# # Check results match expected patterns |
| 111 | +# assert binary_result[0].to_text() == Bit("101").to_text() # Positive values (>0) become 1, negative/zero become 0 |
| 112 | +# assert binary_result[1].to_text() == Bit("010").to_text() # First row: [0.1 (>0), -0.2 (<0), 0.3 (>0)] → "101" |
| 113 | +# # Second row: [-0.1 (<0), 0.2 (>0), -0.3 (<0)] → "010" |
| 114 | + |
| 115 | +# # Test with numpy array |
| 116 | +# numpy_embeddings = np.array([[0.1, -0.2, 0.3], [-0.1, 0.2, -0.3]]) |
| 117 | +# binary_result = store._binary_quantize(numpy_embeddings) |
| 118 | +# assert len(binary_result) == 2 |
| 119 | + |
| 120 | +# assert binary_result[0].to_text() == Bit("101").to_text() |
| 121 | +# assert binary_result[1].to_text() == Bit("010").to_text() |
| 122 | + |
| 123 | + |
| 124 | +# @pytest.mark.asyncio |
| 125 | +# async def test_initialize_creates_tables_and_function(vector_store): |
| 126 | +# """Test that initialize creates the necessary tables and functions""" |
| 127 | +# vector_store.initialize() |
| 128 | +# # Check if the table exists |
| 129 | +# result = vector_store.conn.execute( |
| 130 | +# "SELECT EXISTS (SELECT FROM information_schema.tables WHERE table_name = 'multi_vector_embeddings')" |
| 131 | +# ).fetchone() |
| 132 | +# table_exists = result[0] |
| 133 | +# assert table_exists is True |
| 134 | + |
| 135 | +# logger.info("Table exists!") |
| 136 | + |
| 137 | +# # Check if the max_sim function exists |
| 138 | +# result = vector_store.conn.execute("SELECT EXISTS (SELECT FROM pg_proc WHERE proname = 'max_sim')").fetchone() |
| 139 | +# function_exists = result[0] |
| 140 | +# logger.info(f"Function exists {function_exists}") |
| 141 | +# assert function_exists is True |
| 142 | + |
| 143 | + |
| 144 | +# @pytest.mark.asyncio |
| 145 | +# async def test_database_schema(vector_store): |
| 146 | +# """Test that the database schema matches our expectations""" |
| 147 | +# # Check columns in the table |
| 148 | +# result = vector_store.conn.execute( |
| 149 | +# "SELECT column_name, data_type FROM information_schema.columns " "WHERE table_name = 'multi_vector_embeddings'" |
| 150 | +# ).fetchall() |
| 151 | + |
| 152 | +# # Convert to a dict for easier checking |
| 153 | +# column_dict = {col[0]: col[1] for col in result} |
| 154 | + |
| 155 | +# # Check required columns |
| 156 | +# assert "id" in column_dict |
| 157 | +# assert "document_id" in column_dict |
| 158 | +# assert "chunk_number" in column_dict |
| 159 | +# assert "content" in column_dict |
| 160 | +# assert "embeddings" in column_dict |
149 | 161 |
|
150 | 162 |
|
151 | 163 | # Blackbox Tests - Testing the public API |
|
0 commit comments