-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembed.py
More file actions
67 lines (56 loc) · 2.12 KB
/
Copy pathembed.py
File metadata and controls
67 lines (56 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import os
import openai
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from dataclasses import dataclass
from tenacity import retry, wait_random_exponential, stop_after_attempt
# Module-level configuration, executed on import.
# load_dotenv (python-dotenv) loads variables from a .env file, if one is
# found, into os.environ so OPENAI_KEY can be supplied that way.
load_dotenv()
# Indexing (not .get) means a missing OPENAI_KEY raises KeyError at import time.
openai.api_key = os.environ['OPENAI_KEY']
@dataclass
class Embeddings:
    """Embeddings of one user's tweets, cached on disk as ``<username>.npz``.

    Attributes:
        username: Stem of the cache file name (``<username>.npz``).
        raw: Tweets to embed. Must have a ``text`` column; its index holds the
            tweet ids (presumably named ``id``, which ``read`` relies on --
            TODO confirm against the caller that builds ``raw``).

    ``self.embeddings`` is not a dataclass field; it is set by
    ``load_embeddings`` / ``run`` to a DataFrame indexed like ``raw`` with one
    column per embedding dimension.
    """

    username: str
    raw: pd.DataFrame

    def _cache_path(self):
        """Return the path of this user's ``.npz`` embedding cache."""
        # One spelling for both load and save; np.savez only appends ".npz"
        # when it is missing, so relying on that could diverge from np.load.
        return f'{self.username}.npz'

    def load_embeddings(self, refresh=False):
        """Populate ``self.embeddings`` from the on-disk cache.

        Args:
            refresh: If True, ignore any existing cache and start empty. The
                empty cache is written immediately, overwriting the old file.

        A missing cache file is treated like ``refresh=True``. A cache file
        that exists but cannot be read raises instead of being silently
        replaced, so previously computed embeddings are not discarded.
        """
        # A plain `if` (not `assert`) so refresh still works under `python -O`.
        if not refresh:
            try:
                # The context manager closes the NpzFile handle; the [] lookups
                # have already read the arrays into memory by then.
                with np.load(self._cache_path()) as saved:
                    self.embeddings = pd.DataFrame(saved['emb'], index=saved['idx'])
                return
            except FileNotFoundError:
                pass
        self.embeddings = pd.DataFrame()
        self.save_embeddings()

    def save_embeddings(self):
        """Drop rows containing NaN (in place) and write the cache file."""
        self.embeddings.dropna(inplace=True)
        np.savez(
            self._cache_path(),
            idx=self.embeddings.index.values,
            emb=self.embeddings.values,
        )

    def read(self, tweet):
        """Return the cached embedding for the tweet whose text is ``tweet``.

        Returns a 1-D array for a unique text. If several tweets share the
        text, a 2-D array with one row per match is returned.

        Raises:
            KeyError: if the text is not in ``raw`` or its id has no cached
                embedding.
        """
        by_text = self.raw.reset_index().set_index('text')
        idx = by_text.id.loc[tweet]
        return self.embeddings.loc[idx].values

    @retry(
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(6),
    )
    def embed(self, tweets):
        """Embed a list of strings with ``text-embedding-ada-002``.

        Returns a 2-D array with one row per input string, in input order: the
        API response rows are re-sorted by their ``index`` field before being
        stacked. Failures are retried with randomized exponential backoff
        (1-20 s), up to 6 attempts in total.
        """
        response = openai.Embedding.create(
            input=tweets,
            model='text-embedding-ada-002',
        )
        ordered = pd.DataFrame(response.data).set_index('index').sort_index()
        return np.stack(ordered.embedding.values)

    def run(self, batch_size=200, checkpoint=5):
        """Embed every tweet in ``raw`` that is not already cached.

        Args:
            batch_size: Maximum number of tweets sent per API request (>= 1).
            checkpoint: Save the cache after every ``checkpoint`` batches.

        Afterwards ``self.embeddings`` is aligned to ``raw``'s index (entries
        for tweets no longer in ``raw`` are dropped) and saved.

        Raises:
            ValueError: if ``batch_size`` is less than 1.

        For larger jobs, see:
        https://github.com/openai/openai-cookbook/blob/main/examples/api_request_parallel_processor.py
        """
        if batch_size < 1:
            raise ValueError(f'batch_size must be >= 1, got {batch_size}')
        self.load_embeddings()
        new = self.raw.loc[self.raw.index.difference(self.embeddings.index)]
        # Fixed-size positional slices. When nothing is new, the loop body
        # never runs. The old array_split version instead sent an empty batch
        # to the API, which the retry decorator retried until it gave up.
        for j, start in enumerate(range(0, len(new), batch_size)):
            batch = new.iloc[start:start + batch_size]
            self.embeddings = pd.concat([
                self.embeddings,
                pd.DataFrame(self.embed(list(batch.text)), index=batch.index),
            ])
            if (j + 1) % checkpoint == 0:
                self.save_embeddings()
        self.embeddings = self.embeddings.loc[self.raw.index]
        self.save_embeddings()