1414from semhash .datamodels import DeduplicationResult , DuplicateRecord , FilterResult , Record
1515from semhash .index import Index
1616from semhash .records import add_scores_to_records , map_deduplication_result_to_strings
17- from semhash .utils import Encoder , compute_candidate_limit , to_frozendict
17+ from semhash .utils import (
18+ Encoder ,
19+ compute_candidate_limit ,
20+ featurize ,
21+ prepare_records ,
22+ remove_exact_duplicates ,
23+ to_frozendict ,
24+ )
1825
1926
2027class SemHash (Generic [Record ]):
@@ -33,95 +40,6 @@ def __init__(self, index: Index, model: Encoder, columns: Sequence[str], was_str
3340 self ._was_string = was_string
3441 self ._ranking_cache : FilterResult | None = None
3542
36- @staticmethod
37- def _featurize (
38- records : Sequence [dict [str , str ]],
39- columns : Sequence [str ],
40- model : Encoder ,
41- ) -> np .ndarray :
42- """
43- Featurize a list of records using the model.
44-
45- :param records: A list of records.
46- :param columns: Columns to featurize.
47- :param model: An Encoder model.
48- :return: The embeddings of the records.
49- """
50- # Extract the embeddings for each column across all records
51- embeddings_per_col = []
52- for col in columns :
53- col_texts = [r [col ] for r in records ]
54- col_emb = model .encode (col_texts )
55- embeddings_per_col .append (np .asarray (col_emb ))
56-
57- return np .concatenate (embeddings_per_col , axis = 1 )
58-
59- @classmethod
60- def _remove_exact_duplicates (
61- cls ,
62- records : Sequence [dict [str , str ]],
63- columns : Sequence [str ],
64- reference_records : list [list [dict [str , str ]]] | None = None ,
65- ) -> tuple [list [dict [str , str ]], list [tuple [dict [str , str ], list [dict [str , str ]]]]]:
66- """
67- Remove exact duplicates based on the unpacked string representation of each record.
68-
69- If reference_records is None, the function will only check for duplicates within the records list.
70-
71- :param records: A list of records to check for exact duplicates.
72- :param columns: Columns to unpack.
73- :param reference_records: A list of record groups to compare against. Each group is a list of exact duplicates and is already unpacked.
74- :return: A tuple of the deduplicated records and a list of duplicates, where each duplicate is a (record, matching records) pair.
75- """
76- deduplicated = []
77- duplicates = []
78-
79- column_set = set (columns )
80- # Build a seen set from reference_records if provided
81- seen : defaultdict [frozendict [str , str ], list [dict [str , str ]]] = defaultdict (list )
82- if reference_records is not None :
83- for record_set in reference_records :
84- key = to_frozendict (record_set [0 ], column_set )
85- seen [key ] = list (record_set )
86- in_one_set = reference_records is None
87-
88- for record in records :
89- frozen_record = frozendict ({k : v for k , v in record .items () if k in column_set })
90- if duplicated_records := seen .get (frozen_record ):
91- duplicates .append ((record , duplicated_records ))
92- else :
93- deduplicated .append (record )
94- # Only add current documents to seen if no reference set is used
95- if in_one_set :
96- seen [frozen_record ].append (record )
97-
98- return deduplicated , duplicates
99-
100- @staticmethod
101- def _prepare_records (
102- records : Sequence [Record ], columns : Sequence [str ] | None
103- ) -> tuple [list [dict [str , str ]], Sequence [str ], bool ]:
104- """
105- Validate and prepare records for processing.
106-
107- :param records: A list of records (strings or dictionaries).
108- :param columns: Columns to use if records are dictionaries.
109- :return: Tuple of (dict_records, columns, was_string).
110- :raises ValueError: If columns are not provided for dictionary records.
111- """
112- if columns is None and isinstance (records [0 ], dict ):
113- raise ValueError ("Columns must be specified when passing dictionaries." )
114-
115- if isinstance (records [0 ], str ):
116- columns = ["text" ]
117- dict_records : list [dict [str , str ]] = [{"text" : str (record )} for record in records ]
118- was_string = True
119- else :
120- dict_records = list (records )
121- was_string = False
122-
123- return dict_records , columns , was_string
124-
12543 @classmethod
12644 def from_embeddings (
12745 cls ,
@@ -152,10 +70,10 @@ def from_embeddings(
15270 raise ValueError (f"Number of embeddings ({ len (embeddings )} ) must match number of records ({ len (records )} )" )
15371
15472 # Prepare and validate records
155- dict_records , columns , was_string = cls . _prepare_records (records , columns )
73+ dict_records , columns , was_string = prepare_records (records , columns )
15674
15775 # Remove exact duplicates
158- deduplicated_records , exact_duplicates = cls . _remove_exact_duplicates (dict_records , columns )
76+ deduplicated_records , exact_duplicates = remove_exact_duplicates (dict_records , columns )
15977
16078 # Build items list. Each item is a list of exact duplicates
16179 items : list [list [dict [str , str ]]] = [[record ] for record in deduplicated_records ]
@@ -208,14 +126,14 @@ def from_records(
208126 :return: A SemHash instance with a fitted vicinity index.
209127 """
210128 # Prepare and validate records
211- dict_records , columns , was_string = cls . _prepare_records (records , columns )
129+ dict_records , columns , was_string = prepare_records (records , columns )
212130
213131 # If no model is provided, load the default model
214132 if model is None :
215133 model = StaticModel .from_pretrained ("minishlab/potion-base-8M" )
216134
217135 # Remove exact duplicates
218- deduplicated_records , duplicates = cls . _remove_exact_duplicates (dict_records , columns )
136+ deduplicated_records , duplicates = remove_exact_duplicates (dict_records , columns )
219137
220138 col_set = set (columns )
221139 duplicate_map = defaultdict (list )
@@ -231,7 +149,7 @@ def from_records(
231149 items .append (i )
232150
233151 # Create embeddings for deduplicated records only
234- embeddings = cls . _featurize (deduplicated_records , columns , model )
152+ embeddings = featurize (deduplicated_records , columns , model )
235153
236154 # Build the Vicinity index
237155 backend = ann_backend if use_ann else Backend .BASIC
@@ -263,7 +181,7 @@ def deduplicate(
263181 dict_records = self ._validate_if_strings (records )
264182
265183 # Remove exact duplicates before embedding
266- dict_records , exact_duplicates = self . _remove_exact_duplicates (
184+ dict_records , exact_duplicates = remove_exact_duplicates (
267185 records = dict_records , columns = self .columns , reference_records = self .index .items
268186 )
269187 duplicate_records = []
@@ -279,7 +197,7 @@ def deduplicate(
279197 )
280198
281199 # Compute embeddings for the new records
282- embeddings = self . _featurize (records = dict_records , columns = self .columns , model = self .model )
200+ embeddings = featurize (records = dict_records , columns = self .columns , model = self .model )
283201 # Query the fitted index
284202 results = self .index .query_threshold (embeddings , threshold = threshold )
285203
@@ -536,7 +454,7 @@ def _rank_by_average_similarity(
536454 :return: A FilterResult containing the ranking (records sorted and their average similarity scores).
537455 """
538456 dict_records = self ._validate_if_strings (records )
539- embeddings = self . _featurize (records = dict_records , columns = self .columns , model = self .model )
457+ embeddings = featurize (records = dict_records , columns = self .columns , model = self .model )
540458 results = self .index .query_top_k (embeddings , k = 100 , vectors_are_in_index = False )
541459
542460 # Compute the average similarity for each record.
@@ -600,7 +518,7 @@ def _diversify(
600518 if not candidates :
601519 return FilterResult (selected = [], filtered = [], scores_selected = [], scores_filtered = [])
602520
603- embeddings = self . _featurize (records = candidates , columns = self .columns , model = self .model )
521+ embeddings = featurize (records = candidates , columns = self .columns , model = self .model )
604522 result = diversify (
605523 embeddings = embeddings ,
606524 scores = np .array (relevance ),
0 commit comments