-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvector_db.py
More file actions
35 lines (31 loc) · 1.35 KB
/
Copy pathvector_db.py
File metadata and controls
35 lines (31 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
class QdrantStorage:
    """Thin wrapper around a single Qdrant collection.

    On construction the collection is created (cosine distance) if it does not
    already exist. ``search`` reads back the ``text`` and ``source`` keys of the
    stored payloads, so points written via ``upsert`` should carry them.
    """

    def __init__(self, url: str = "http://localhost:6333", collection: str = "docs",
                 dim: int = 3072, timeout: int = 30):
        """Connect to Qdrant and make sure the collection exists.

        Args:
            url: Qdrant server URL.
            collection: Name of the collection to read from / write to.
            dim: Vector size used only when the collection has to be created.
                An existing collection is reused as-is; its vector size is not
                checked against ``dim``.
            timeout: Request timeout forwarded to ``QdrantClient``.
        """
        self.client = QdrantClient(url=url, timeout=timeout)
        self.collection = collection
        if not self.client.collection_exists(self.collection):
            self.client.create_collection(
                collection_name=self.collection,
                vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
            )

    def upsert(self, ids, vectors, payloads):
        """Insert or overwrite points in the collection.

        ``ids``, ``vectors`` and ``payloads`` are parallel: the i-th element of
        each forms one point. Any iterables are accepted.

        Raises:
            ValueError: if the three inputs do not have the same length.
        """
        # Materialize once so generators work and lengths can be compared.
        ids, vectors, payloads = list(ids), list(vectors), list(payloads)
        if not len(ids) == len(vectors) == len(payloads):
            raise ValueError(
                "ids, vectors and payloads must have the same length "
                f"(got {len(ids)}, {len(vectors)}, {len(payloads)})"
            )
        points = [
            PointStruct(id=point_id, vector=vector, payload=payload)
            for point_id, vector, payload in zip(ids, vectors, payloads)
        ]
        self.client.upsert(collection_name=self.collection, points=points)

    def search(self, query_vector, top_k: int = 5) -> dict:
        """Fetch the ``top_k`` nearest points and collect their payload text.

        Returns:
            ``{"contexts": [...], "sources": [...]}``. ``contexts`` holds the
            non-empty ``text`` payload values in the order Qdrant returned the
            hits; ``sources`` holds the distinct non-empty ``source`` values of
            those same hits, in first-seen order.
        """
        query_points = getattr(self.client, "query_points", None)
        if query_points is not None:
            # query_points() supersedes the deprecated QdrantClient.search().
            hits = query_points(
                collection_name=self.collection,
                query=query_vector,
                with_payload=True,
                limit=top_k,
            ).points
        else:
            # Fallback for qdrant-client versions without query_points().
            hits = self.client.search(
                collection_name=self.collection,
                query_vector=query_vector,
                with_payload=True,
                limit=top_k,
            )

        contexts = []
        sources = {}  # dict used as an insertion-ordered set (deterministic order)
        for hit in hits:
            payload = getattr(hit, "payload", None) or {}
            text = payload.get("text", "")
            if not text:
                continue
            contexts.append(text)
            source = payload.get("source", "")
            if source:
                sources.setdefault(source, None)
        return {"contexts": contexts, "sources": list(sources)}