Skip to content

Commit 3aff31b

Browse files
committed
fix tests, make tokenizer changes better
1 parent 796e18f commit 3aff31b

5 files changed

Lines changed: 55 additions & 48 deletions

File tree

model2vec/distill/distillation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def distill_from_model(
7373

7474
n_tokens_before = len(vocabulary)
7575
# Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary.
76-
all_tokens = clean_and_create_vocabulary(tokenizer, vocabulary, token_remove_regex=token_remove_regex)
76+
all_tokens, backend_tokenizer = clean_and_create_vocabulary(
77+
tokenizer, vocabulary, token_remove_regex=token_remove_regex
78+
)
7779
n_tokens_after = len([token for token in all_tokens if not token.is_internal])
7880
if n_tokens_before:
7981
logger.info(

model2vec/tokenizer/normalizer.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,34 @@
11
from string import punctuation
22

3-
from tokenizers import Regex
3+
from tokenizers import Regex, Tokenizer
44
from tokenizers.normalizers import Normalizer, Replace, Sequence, Strip
55

66

7-
def prepare_normalizer(
8-
normalizer: Normalizer,
9-
) -> Normalizer:
7+
def replace_normalizer(
8+
tokenizer: Tokenizer,
9+
) -> Tokenizer:
1010
"""
11-
Prepare the normalizer for the tokenizer.
11+
Replace the normalizer for the tokenizer.
1212
13-
This function sets the normalizer for the tokenizer based on the provided normalizer type.
14-
If no normalizer is provided, it uses the default one.
13+
The new normalizer will replace punctuation with a space before and after the punctuation.
14+
It will also replace multiple spaces with a single space and strip the right side of the string.
15+
If the tokenizer already has a normalizer, it will be added to the new normalizer.
16+
If the tokenizer does not have a normalizer, a new normalizer will be created.
1517
16-
:param normalizer: The tokenizer to prepare.
17-
:return: The prepared tokenizer.
18+
:param tokenizer: The tokenizer to change.
19+
:return: The tokenizer with a replaced normalizer.
1820
"""
21+
normalizer = tokenizer.normalizer
1922
new_normalizers = []
2023
for char in punctuation:
2124
new_normalizers.append(Replace(char, f" {char} "))
2225

2326
new_normalizers.append(Replace(Regex(r"\s+"), " "))
2427
new_normalizers.append(Strip(right=True))
2528
if normalizer is None:
26-
return Sequence(new_normalizers)
29+
normalizer = Sequence(new_normalizers)
30+
else:
31+
normalizer = Sequence([normalizer] + new_normalizers) # type: ignore
32+
tokenizer.normalizer = normalizer
2733

28-
return Sequence([normalizer] + new_normalizers) # type: ignore
34+
return tokenizer

model2vec/tokenizer/pretokenizer.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
22

3+
import json
34
from typing import Any
45

6+
from tokenizers import Tokenizer
7+
58
_FORBIDDEN_PRETOKENIZERS = (
69
"WhiteSpace",
710
"WhitespaceSplit",
@@ -28,26 +31,27 @@ def _fix_single_pretokenizer(pre_tokenizer: dict[str, Any]) -> dict[str, Any] |
2831
return pre_tokenizer
2932

3033

31-
def fix_pretokenizer(pretokenizer: dict[str, Any] | None) -> dict[str, Any]:
34+
def replace_pretokenizer(tokenizer: Tokenizer) -> Tokenizer:
3235
"""Fixes a single pretokenizer to allow multiword units."""
33-
if pretokenizer is None:
34-
return _BASIC_METASPACE
36+
tokenizer_json = json.loads(tokenizer.to_str())
37+
pre_tokenizer_json = tokenizer_json.get("pre_tokenizer", None)
38+
39+
if pre_tokenizer_json is None:
40+
pre_tokenizer_json = _BASIC_METASPACE
3541

36-
if pretokenizer["type"] == "Sequence":
42+
elif pre_tokenizer_json["type"] == "Sequence":
3743
new_pretokenizers = []
38-
for single_pretokenizer in pretokenizer["pretokenizers"]:
44+
for single_pretokenizer in pre_tokenizer_json["pretokenizers"]:
3945
new_pretokenizer = _fix_single_pretokenizer(single_pretokenizer)
4046
if new_pretokenizer is not None:
4147
new_pretokenizers.append(new_pretokenizer)
42-
pretokenizer["pretokenizers"] = new_pretokenizers
43-
44-
if not pretokenizer:
45-
return _BASIC_METASPACE
4648

47-
return pretokenizer
49+
if new_pretokenizers:
50+
pre_tokenizer_json["pretokenizers"] = new_pretokenizers
51+
else:
52+
pre_tokenizer_json = _BASIC_METASPACE
4853

49-
single_pretokenizer = _fix_single_pretokenizer(pretokenizer)
50-
if single_pretokenizer is None:
51-
return _BASIC_METASPACE
54+
pre_tokenizer_json = _fix_single_pretokenizer(pre_tokenizer_json) or _BASIC_METASPACE
55+
tokenizer_json["pre_tokenizer"] = pre_tokenizer_json
5256

53-
return single_pretokenizer
57+
return tokenizer.from_str(json.dumps(tokenizer_json))

model2vec/tokenizer/tokenizer.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414

1515
from model2vec.tokenizer.datamodels import Token
1616
from model2vec.tokenizer.model import process_tokenizer
17-
from model2vec.tokenizer.normalizer import prepare_normalizer
18-
from model2vec.tokenizer.pretokenizer import fix_pretokenizer
17+
from model2vec.tokenizer.normalizer import replace_normalizer
18+
from model2vec.tokenizer.pretokenizer import replace_pretokenizer
1919

2020
logger = logging.getLogger(__name__)
2121

@@ -54,11 +54,7 @@ def replace_vocabulary(
5454
tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
5555
) -> Tokenizer:
5656
"""Replace the vocabulary of a tokenizer with a new one."""
57-
tokenizer = tokenizer.from_str(tokenizer.to_str())
58-
tokenizer.normalizer = prepare_normalizer(tokenizer.normalizer) # type: ignore[assignment] # Is just wrong
5957
tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())
60-
tokenizer_json["pre_tokenizer"] = fix_pretokenizer(tokenizer_json["pre_tokenizer"])
61-
6258
added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]
6359

6460
pre_tokenized_tokens = [x.normalized_form for x in new_vocabulary]
@@ -102,7 +98,7 @@ def clean_and_create_vocabulary(
10298
tokenizer: PreTrainedTokenizerFast,
10399
vocabulary: list[str],
104100
token_remove_regex: re.Pattern | None,
105-
) -> list[Token]:
101+
) -> tuple[list[Token], Tokenizer]:
106102
"""Cleans a vocabulary by removing duplicates and tokens that were already in the vocabulary."""
107103
seen_tokens = set()
108104
post_normalize_seen_tokens = set()
@@ -115,15 +111,12 @@ def clean_and_create_vocabulary(
115111
internal_vocab: dict[str, int] = tokenizer.get_vocab()
116112
internal_tokens: list[str] = [k for k, _ in sorted(internal_vocab.items(), key=lambda x: x[1])]
117113

118-
cleaned_vocabulary = _process_internal_tokens(tokenizer, internal_tokens, token_remove_regex)
119-
internal_tokens_set = {token.form for token in cleaned_vocabulary}
120-
121-
# Change the backend tokenizer to the new one.
114+
# Copy the backend tokenizer to avoid modifying the original.
122115
backend_tokenizer = backend_tokenizer.from_str(backend_tokenizer.to_str())
123-
backend_tokenizer.normalizer = prepare_normalizer(backend_tokenizer.normalizer) # type: ignore[assignment] # Is just wrong
124-
tokenizer_json: dict[str, Any] = json.loads(backend_tokenizer.to_str())
125-
tokenizer_json["pre_tokenizer"] = fix_pretokenizer(tokenizer_json["pre_tokenizer"])
126-
backend_tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))
116+
backend_tokenizer = replace_normalizer(backend_tokenizer)
117+
118+
cleaned_vocabulary = _process_internal_tokens(tokenizer, backend_tokenizer, internal_tokens, token_remove_regex)
119+
internal_tokens_set = {token.form for token in cleaned_vocabulary}
127120

128121
normalizer: Normalizer | None = backend_tokenizer.normalizer
129122
for token in vocabulary:
@@ -178,11 +171,14 @@ def clean_and_create_vocabulary(
178171
if n_empty:
179172
logger.warning(f"Removed {n_empty} empty tokens.")
180173

181-
return cleaned_vocabulary
174+
return cleaned_vocabulary, replace_pretokenizer(backend_tokenizer)
182175

183176

184177
def _process_internal_tokens(
185-
tokenizer: PreTrainedTokenizerFast, internal_tokens: list[str], token_remove_regex: re.Pattern | None
178+
tokenizer: PreTrainedTokenizerFast,
179+
backend_tokenizer: Tokenizer,
180+
internal_tokens: list[str],
181+
token_remove_regex: re.Pattern | None,
186182
) -> list[Token]:
187183
"""Clean internal tokens."""
188184
# Get the pad and unk token from the tokenizer.
@@ -193,7 +189,6 @@ def _process_internal_tokens(
193189
added_tokens_to_remove = set(tokenizer.added_tokens_encoder) - added_tokens_to_keep
194190
cleaned_internal_tokens: list[Token] = []
195191

196-
backend_tokenizer = tokenizer.backend_tokenizer
197192
# Figure out whether token is a subword or not.
198193
encoded = backend_tokenizer.encode(f" {'a' * 25}", add_special_tokens=False)
199194
first_token, second_token, *_ = encoded.tokens
@@ -378,7 +373,7 @@ def create_tokenizer(
378373
"""
379374
unk_token = cast(str | None, tokenizer.special_tokens_map.get("unk_token"))
380375
pad_token = cast(str | None, tokenizer.special_tokens_map.get("pad_token"))
381-
cleaned_vocabulary = clean_and_create_vocabulary(tokenizer, vocabulary, token_remove_regex)
382-
new_tokenizer = replace_vocabulary(tokenizer.backend_tokenizer, cleaned_vocabulary, unk_token, pad_token)
376+
cleaned_vocabulary, backend_tokenizer = clean_and_create_vocabulary(tokenizer, vocabulary, token_remove_regex)
377+
new_tokenizer = replace_vocabulary(backend_tokenizer, cleaned_vocabulary, unk_token, pad_token)
383378

384379
return PreTrainedTokenizerFast(tokenizer_object=new_tokenizer)

tests/test_distillation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from huggingface_hub.errors import RepositoryNotFoundError
2323
except ImportError:
2424
# For huggingface_hub<0.25.0
25-
from huggingface_hub.utils._errors import RepositoryNotFoundError
25+
from huggingface_hub.utils._errors import RepositoryNotFoundError # type: ignore
2626

2727
rng = np.random.default_rng()
2828

@@ -275,7 +275,7 @@ def test_clean_and_create_vocabulary(
275275
) -> None:
276276
"""Test the _clean_vocabulary function."""
277277
with caplog.at_level("WARNING"):
278-
tokens = clean_and_create_vocabulary(mock_berttokenizer, added_tokens, None)
278+
tokens, _ = clean_and_create_vocabulary(mock_berttokenizer, added_tokens, None)
279279

280280
cleaned_vocab = [token.form for token in tokens if not token.is_internal]
281281
# Check the cleaned vocabulary matches the expected output

0 commit comments

Comments
 (0)