Skip to content

Commit 5ed061d

Browse files
committed
Refactored SemHash, moved more functions to utils
1 parent ef50401 commit 5ed061d

1 file changed

Lines changed: 62 additions & 60 deletions

File tree

semhash/semhash.py

Lines changed: 62 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,66 @@ def __init__(self, index: Index, model: Encoder, columns: Sequence[str], was_str
4040
self._was_string = was_string
4141
self._ranking_cache: FilterResult | None = None
4242

43+
@classmethod
def from_records(
    cls,
    records: Sequence[Record],
    columns: Sequence[str] | None = None,
    use_ann: bool = True,
    model: Encoder | None = None,
    ann_backend: Backend | str = Backend.USEARCH,
    **kwargs: Any,
) -> SemHash:
    """
    Initialize a SemHash instance from records.

    This removes exact duplicates, featurizes the records, and fits a vicinity index.
    Each item stored in the index is a list whose first element is a unique record and whose
    remaining elements are the exact duplicates of that record, so only one embedding is
    computed per group of exact duplicates.

    :param records: A list of records (strings or dictionaries).
    :param columns: Columns to featurize if records are dictionaries.
    :param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True.
    :param model: (Optional) An Encoder model. If None, the default model is used (minishlab/potion-base-8M).
    :param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH.
    :param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
    :return: A SemHash instance with a fitted vicinity index.
    """
    # Prepare and validate records
    dict_records, columns, was_string = prepare_records(records, columns)

    # If no model is provided, load the default model
    if model is None:
        model = StaticModel.from_pretrained("minishlab/potion-base-8M")

    # Remove exact duplicates
    deduplicated_records, duplicates = remove_exact_duplicates(dict_records, columns)

    # Group the exact duplicates by their frozen (hashable) form over the featurized columns.
    # Only the first element of each `duplicates` entry is needed here.
    col_set = set(columns)
    duplicate_map = defaultdict(list)
    for duplicate, _ in duplicates:
        duplicate_map[to_frozendict(duplicate, col_set)].append(duplicate)

    # One index item per unique record: the record itself followed by its exact duplicates.
    # Read with .get() rather than indexing: indexing a defaultdict inserts an empty list for
    # every record that has no duplicates, growing the map for no benefit.
    items: list[list[dict[str, str]]] = [
        [record, *duplicate_map.get(to_frozendict(record, col_set), ())]
        for record in deduplicated_records
    ]

    # Create embeddings for deduplicated records only
    embeddings = featurize(deduplicated_records, columns, model)

    # Build the Vicinity index; fall back to the basic backend when ANN is disabled
    backend = ann_backend if use_ann else Backend.BASIC
    index = Index.from_vectors_and_items(
        vectors=embeddings,
        items=items,
        backend_type=backend,
        **kwargs,
    )

    return cls(index=index, columns=columns, model=model, was_string=was_string)
102+
43103
@classmethod
44104
def from_embeddings(
45105
cls,
@@ -54,6 +114,8 @@ def from_embeddings(
54114
"""
55115
Initialize a SemHash instance from pre-computed embeddings.
56116
117+
This removes exact duplicates and fits a vicinity index using the provided embeddings.
118+
57119
:param embeddings: Pre-computed embeddings as a numpy array of shape (n_records, embedding_dim).
58120
:param records: A list of records (strings or dictionaries) corresponding to the embeddings.
59121
:param model: The Encoder model used for creating the embeddings.
@@ -102,66 +164,6 @@ def from_embeddings(
102164

103165
return cls(index=index, model=model, columns=columns, was_string=was_string)
104166

105-
@classmethod
def from_records(
    cls,
    records: Sequence[Record],
    columns: Sequence[str] | None = None,
    use_ann: bool = True,
    model: Encoder | None = None,
    ann_backend: Backend | str = Backend.USEARCH,
    **kwargs: Any,
) -> SemHash:
    """
    Initialize a SemHash instance from records.

    This removes exact duplicates, featurizes the records, and fits a vicinity index.
    Each item stored in the index is a list whose first element is a unique record and whose
    remaining elements are the exact duplicates of that record, so only one embedding is
    computed per group of exact duplicates.

    :param records: A list of records (strings or dictionaries).
    :param columns: Columns to featurize if records are dictionaries.
    :param use_ann: Whether to use approximate nearest neighbors (True) or basic search (False). Default is True.
    :param model: (Optional) An Encoder model. If None, the default model is used (minishlab/potion-base-8M).
    :param ann_backend: (Optional) The ANN backend to use if use_ann is True. Defaults to Backend.USEARCH.
    :param **kwargs: Any additional keyword arguments to pass to the Vicinity index.
    :return: A SemHash instance with a fitted vicinity index.
    """
    # Prepare and validate records
    dict_records, columns, was_string = prepare_records(records, columns)

    # If no model is provided, load the default model
    if model is None:
        model = StaticModel.from_pretrained("minishlab/potion-base-8M")

    # Remove exact duplicates
    deduplicated_records, duplicates = remove_exact_duplicates(dict_records, columns)

    # Group the exact duplicates by their frozen (hashable) form over the featurized columns.
    # Only the first element of each `duplicates` entry is needed here.
    col_set = set(columns)
    duplicate_map = defaultdict(list)
    for duplicate, _ in duplicates:
        duplicate_map[to_frozendict(duplicate, col_set)].append(duplicate)

    # One index item per unique record: the record itself followed by its exact duplicates.
    # Read with .get() rather than indexing: indexing a defaultdict inserts an empty list for
    # every record that has no duplicates, growing the map for no benefit.
    items: list[list[dict[str, str]]] = [
        [record, *duplicate_map.get(to_frozendict(record, col_set), ())]
        for record in deduplicated_records
    ]

    # Create embeddings for deduplicated records only
    embeddings = featurize(deduplicated_records, columns, model)

    # Build the Vicinity index; fall back to the basic backend when ANN is disabled
    backend = ann_backend if use_ann else Backend.BASIC
    index = Index.from_vectors_and_items(
        vectors=embeddings,
        items=items,
        backend_type=backend,
        **kwargs,
    )

    return cls(index=index, columns=columns, model=model, was_string=was_string)
164-
165167
def deduplicate(
166168
self,
167169
records: Sequence[Record],

0 commit comments

Comments
 (0)