1414
1515from model2vec .tokenizer .datamodels import Token
1616from model2vec .tokenizer .model import process_tokenizer
17- from model2vec .tokenizer .normalizer import prepare_normalizer
18- from model2vec .tokenizer .pretokenizer import fix_pretokenizer
17+ from model2vec .tokenizer .normalizer import replace_normalizer
18+ from model2vec .tokenizer .pretokenizer import replace_pretokenizer
1919
2020logger = logging .getLogger (__name__ )
2121
@@ -54,11 +54,7 @@ def replace_vocabulary(
5454 tokenizer : Tokenizer , new_vocabulary : list [Token ], unk_token : str | None , pad_token : str | None
5555) -> Tokenizer :
5656 """Replace the vocabulary of a tokenizer with a new one."""
57- tokenizer = tokenizer .from_str (tokenizer .to_str ())
58- tokenizer .normalizer = prepare_normalizer (tokenizer .normalizer ) # type: ignore[assignment] # Is just wrong
5957 tokenizer_json : dict [str , Any ] = json .loads (tokenizer .to_str ())
60- tokenizer_json ["pre_tokenizer" ] = fix_pretokenizer (tokenizer_json ["pre_tokenizer" ])
61-
6258 added_tokens : list [dict [str , Any ]] = tokenizer_json ["added_tokens" ]
6359
6460 pre_tokenized_tokens = [x .normalized_form for x in new_vocabulary ]
@@ -102,7 +98,7 @@ def clean_and_create_vocabulary(
10298 tokenizer : PreTrainedTokenizerFast ,
10399 vocabulary : list [str ],
104100 token_remove_regex : re .Pattern | None ,
105- ) -> list [Token ]:
101+ ) -> tuple [ list [Token ], Tokenizer ]:
106102 """Cleans a vocabulary by removing duplicates and tokens that were already in the vocabulary."""
107103 seen_tokens = set ()
108104 post_normalize_seen_tokens = set ()
@@ -115,15 +111,12 @@ def clean_and_create_vocabulary(
115111 internal_vocab : dict [str , int ] = tokenizer .get_vocab ()
116112 internal_tokens : list [str ] = [k for k , _ in sorted (internal_vocab .items (), key = lambda x : x [1 ])]
117113
118- cleaned_vocabulary = _process_internal_tokens (tokenizer , internal_tokens , token_remove_regex )
119- internal_tokens_set = {token .form for token in cleaned_vocabulary }
120-
121- # Change the backend tokenizer to the new one.
114+ # Copy the backend tokenizer to avoid modifying the original.
122115 backend_tokenizer = backend_tokenizer .from_str (backend_tokenizer .to_str ())
123- backend_tokenizer . normalizer = prepare_normalizer (backend_tokenizer . normalizer ) # type: ignore[assignment] # Is just wrong
124- tokenizer_json : dict [ str , Any ] = json . loads ( backend_tokenizer . to_str ())
125- tokenizer_json [ "pre_tokenizer" ] = fix_pretokenizer ( tokenizer_json [ "pre_tokenizer" ] )
126- backend_tokenizer = Tokenizer . from_str ( json . dumps ( tokenizer_json ))
116+ backend_tokenizer = replace_normalizer (backend_tokenizer )
117+
118+ cleaned_vocabulary = _process_internal_tokens ( tokenizer , backend_tokenizer , internal_tokens , token_remove_regex )
119+ internal_tokens_set = { token . form for token in cleaned_vocabulary }
127120
128121 normalizer : Normalizer | None = backend_tokenizer .normalizer
129122 for token in vocabulary :
@@ -178,11 +171,14 @@ def clean_and_create_vocabulary(
178171 if n_empty :
179172 logger .warning (f"Removed { n_empty } empty tokens." )
180173
181- return cleaned_vocabulary
174+ return cleaned_vocabulary , replace_pretokenizer ( backend_tokenizer )
182175
183176
184177def _process_internal_tokens (
185- tokenizer : PreTrainedTokenizerFast , internal_tokens : list [str ], token_remove_regex : re .Pattern | None
178+ tokenizer : PreTrainedTokenizerFast ,
179+ backend_tokenizer : Tokenizer ,
180+ internal_tokens : list [str ],
181+ token_remove_regex : re .Pattern | None ,
186182) -> list [Token ]:
187183 """Clean internal tokens."""
188184 # Get the pad and unk token from the tokenizer.
@@ -193,7 +189,6 @@ def _process_internal_tokens(
193189 added_tokens_to_remove = set (tokenizer .added_tokens_encoder ) - added_tokens_to_keep
194190 cleaned_internal_tokens : list [Token ] = []
195191
196- backend_tokenizer = tokenizer .backend_tokenizer
197192 # Figure out whether token is a subword or not.
198193 encoded = backend_tokenizer .encode (f" { 'a' * 25 } " , add_special_tokens = False )
199194 first_token , second_token , * _ = encoded .tokens
@@ -378,7 +373,7 @@ def create_tokenizer(
378373 """
379374 unk_token = cast (str | None , tokenizer .special_tokens_map .get ("unk_token" ))
380375 pad_token = cast (str | None , tokenizer .special_tokens_map .get ("pad_token" ))
381- cleaned_vocabulary = clean_and_create_vocabulary (tokenizer , vocabulary , token_remove_regex )
382- new_tokenizer = replace_vocabulary (tokenizer . backend_tokenizer , cleaned_vocabulary , unk_token , pad_token )
376+ cleaned_vocabulary , backend_tokenizer = clean_and_create_vocabulary (tokenizer , vocabulary , token_remove_regex )
377+ new_tokenizer = replace_vocabulary (backend_tokenizer , cleaned_vocabulary , unk_token , pad_token )
383378
384379 return PreTrainedTokenizerFast (tokenizer_object = new_tokenizer )
0 commit comments