-
Notifications
You must be signed in to change notification settings - Fork 496
Expand file tree
/
Copy pathmain.py
More file actions
145 lines (127 loc) · 4.53 KB
/
main.py
File metadata and controls
145 lines (127 loc) · 4.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from dotenv import load_dotenv
from psycopg_pool import ConnectionPool
from pgvector.psycopg import register_vector
import cocoindex
import os
import functools
from numpy.typing import NDArray
import numpy as np
from datetime import timedelta
@cocoindex.transform_flow()
def text_to_embedding(
    text: cocoindex.DataSlice[str],
) -> cocoindex.DataSlice[NDArray[np.float32]]:
    """
    Map a slice of text to embedding vectors using a local SentenceTransformer.

    Both the indexing flow and the query handler call this function, so
    documents and queries are always embedded by the same model.

    To use a hosted embedding model instead, replace the spec below with e.g.:

        cocoindex.functions.EmbedText(
            api_type=cocoindex.LlmApiType.OPENAI,
            model="text-embedding-3-small",
        )
    """
    embed_spec = cocoindex.functions.SentenceTransformerEmbed(
        model="sentence-transformers/all-MiniLM-L6-v2"
    )
    return text.transform(embed_spec)
@cocoindex.flow_def(name="TextEmbedding")
def text_embedding_flow(
    flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope
) -> None:
    """
    Build the "TextEmbedding" indexing flow.

    Reads files from the local ``markdown_files`` directory, re-checked every
    5 seconds. Each file's content is split into Markdown-aware chunks and every
    chunk is embedded with ``text_to_embedding``. One row per
    (filename, location) is exported to the ``doc_embeddings`` Postgres target,
    which has a cosine-similarity vector index on ``embedding``.
    """
    # Build the source, splitter and vector-index specs first; the flow wiring
    # below refers to them.
    markdown_source = cocoindex.sources.LocalFile(path="markdown_files")
    chunker = cocoindex.functions.SplitRecursively()
    embedding_index = cocoindex.VectorIndexDef(
        field_name="embedding",
        metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
    )

    data_scope["documents"] = flow_builder.add_source(
        markdown_source,
        refresh_interval=timedelta(seconds=5),
    )
    doc_embeddings = data_scope.add_collector()

    with data_scope["documents"].row() as document:
        # Chunks overlap so that text near a boundary appears in both neighbours.
        document["chunks"] = document["content"].transform(
            chunker,
            language="markdown",
            chunk_size=2000,
            chunk_overlap=500,
        )
        with document["chunks"].row() as piece:
            piece["embedding"] = text_to_embedding(piece["text"])
            # One collected row per chunk; (filename, location) is its key.
            doc_embeddings.collect(
                filename=document["filename"],
                location=piece["location"],
                text=piece["text"],
                embedding=piece["embedding"],
            )

    doc_embeddings.export(
        "doc_embeddings",
        cocoindex.targets.Postgres(),
        primary_key_fields=["filename", "location"],
        vector_indexes=[embedding_index],
    )
@functools.cache
def connection_pool() -> ConnectionPool:
    """
    Return the shared connection pool for the CocoIndex database.

    The pool is created on the first call and memoized by ``functools.cache``,
    so later calls return the same instance. The connection string comes from
    the ``COCOINDEX_DATABASE_URL`` environment variable; if it is unset,
    ``os.environ`` raises ``KeyError``.
    """
    # Pass ``open=True`` explicitly. psycopg_pool >= 3.2 emits a
    # DeprecationWarning when ``open`` is omitted, because its default is
    # scheduled to become ``False``. ``True`` keeps the existing behaviour of
    # opening the pool at construction time.
    return ConnectionPool(os.environ["COCOINDEX_DATABASE_URL"], open=True)
# Maximum number of chunks returned per query (used as the SQL LIMIT).
TOP_K = 5


# Declaring it as a query handler, so that you can easily run queries in CocoInsight.
@text_embedding_flow.query_handler(
    result_fields=cocoindex.QueryHandlerResultFields(
        embedding=["embedding"],
        score="score",
    ),
)
def search(query: str) -> cocoindex.QueryOutput:
    """
    Return the TOP_K indexed chunks nearest to ``query``.

    The query is embedded with the same ``text_to_embedding`` flow used for
    indexing. It is then ranked against the ``doc_embeddings`` table with
    pgvector's ``<=>`` distance operator, smallest distance first. Each result
    is a dict with ``filename``, ``text`` and ``score``, where
    ``score = 1.0 - distance`` (higher means more similar).
    """
    # Resolve the Postgres table backing the "doc_embeddings" export target.
    table_name = cocoindex.utils.get_target_default_name(
        text_embedding_flow, "doc_embeddings"
    )
    query_vector = text_to_embedding.eval(query)

    sql = f"""
        SELECT filename, text, embedding <=> %s AS distance
        FROM {table_name} ORDER BY distance LIMIT %s
    """
    with connection_pool().connection() as conn:
        # Register the pgvector type adapters on this connection so the
        # numpy query vector can be bound as a parameter.
        register_vector(conn)
        with conn.cursor() as cur:
            cur.execute(sql, (query_vector, TOP_K))
            rows = cur.fetchall()

    results = [
        {"filename": filename, "text": chunk_text, "score": 1.0 - distance}
        for filename, chunk_text, distance in rows
    ]
    return cocoindex.QueryOutput(
        results=results,
        query_info=cocoindex.QueryInfo(
            embedding=query_vector,
            similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY,
        ),
    )
def _main() -> None:
    """
    Run an interactive search loop on stdin/stdout.

    Repeatedly prompts for a query, runs ``search`` on it and prints each hit's
    score, filename and text. The loop ends on an empty or whitespace-only
    line, or when stdin reaches end of input (e.g. Ctrl-D or exhausted piped
    input).
    """
    while True:
        try:
            query = input("Enter search query (or Enter to quit): ")
        except EOFError:
            # input() raises EOFError at end of stream; treat it like "quit"
            # instead of crashing with a traceback.
            print()
            break
        # A blank or whitespace-only line means "quit"; embedding it would only
        # produce meaningless matches.
        if not query.strip():
            break
        query_output = search(query)
        print("\nSearch results:")
        for result in query_output.results:
            print(f"[{result['score']:.3f}] {result['filename']}")
            print(f" {result['text']}")
            print("---")
        print()
if __name__ == "__main__":
    # Script entry point. The .env file is loaded first so that any variables
    # it defines are in os.environ before they are read. connection_pool()
    # reads COCOINDEX_DATABASE_URL from there. cocoindex.init() presumably
    # reads its settings from the environment too (NOTE(review): confirm).
    load_dotenv()
    cocoindex.init()
    _main()