Skip to content

Commit ec376c1

Browse files
committed
store all relevant info in safetensors
1 parent 75decb5 commit ec376c1

6 files changed

Lines changed: 32 additions & 23 deletions

File tree

model2vec/hf_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def save_pretrained(
2525
create_model_card: bool = True,
2626
subfolder: str | None = None,
2727
weights: np.ndarray | None = None,
28+
mapping: np.ndarray | None = None,
2829
**kwargs: Any,
2930
) -> None:
3031
"""
@@ -37,6 +38,7 @@ def save_pretrained(
3738
:param create_model_card: Whether to create a model card.
3839
:param subfolder: The subfolder to save the model in.
3940
:param weights: The weights of the model. If None, no weights are saved.
41+
:param mapping: The token mapping of the model. If None, there is no token mapping.
4042
:param **kwargs: Any additional arguments.
4143
"""
4244
folder_path = folder_path / subfolder if subfolder else folder_path
@@ -45,6 +47,8 @@ def save_pretrained(
4547
model_weights = {"embeddings": embeddings}
4648
if weights is not None:
4749
model_weights["weights"] = weights
50+
if mapping is not None:
51+
model_weights["mapping"] = mapping
4852

4953
save_file(model_weights, folder_path / "model.safetensors")
5054
tokenizer.save(str(folder_path / "tokenizer.json"), pretty=False)
@@ -106,7 +110,7 @@ def load_pretrained(
106110
subfolder: str | None = None,
107111
token: str | None = None,
108112
from_sentence_transformers: bool = False,
109-
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any], np.ndarray | None]:
113+
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any], np.ndarray | None, np.ndarray | None]:
110114
"""
111115
Loads a pretrained model from a folder.
112116
@@ -185,18 +189,23 @@ def load_pretrained(
185189
if from_sentence_transformers:
186190
embeddings = opened_tensor_file.get_tensor("embedding.weight")
187191
weights = None
192+
mapping = None
188193
else:
189194
embeddings = opened_tensor_file.get_tensor("embeddings")
190195
try:
191196
weights = opened_tensor_file.get_tensor("weights")
192197
except Exception:
193198
# Catch the generic Exception because safetensors does not export its own error types.
194199
weights = None
200+
try:
201+
mapping = opened_tensor_file.get_tensor("mapping")
202+
except Exception:
203+
mapping = None
195204

196205
tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
197206
config = json.load(open(config_path))
198207

199-
return embeddings, tokenizer, config, metadata, weights
208+
return embeddings, tokenizer, config, metadata, weights, mapping
200209

201210

202211
def _get_metadata_from_readme(readme_path: Path) -> dict[str, Any]:

model2vec/model.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
base_model_name: str | None = None,
3131
language: list[str] | None = None,
3232
weights: np.ndarray | None = None,
33-
token_mapping: list[int] | None = None,
33+
token_mapping: np.ndarray | None = None,
3434
) -> None:
3535
"""
3636
Initialize the StaticModel.
@@ -63,7 +63,7 @@ def __init__(
6363
self.weights = weights
6464
# Convert to an array for fast lookups
6565
# We can't use `or` to short-circuit here because the truth value of an np.ndarray is ambiguous.
66-
self.token_mapping: np.ndarray | None = None if token_mapping is None else np.asarray(token_mapping)
66+
self.token_mapping: np.ndarray | None = token_mapping
6767

6868
self.tokenizer = tokenizer
6969
self.unk_token_id: int | None
@@ -121,9 +121,6 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
121121
"""
122122
from model2vec.hf_utils import save_pretrained
123123

124-
if self.token_mapping is not None:
125-
self.config["token_mapping"] = self.token_mapping.tolist()
126-
127124
save_pretrained(
128125
folder_path=Path(path),
129126
embeddings=self.embedding,
@@ -134,6 +131,7 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
134131
model_name=model_name,
135132
subfolder=subfolder,
136133
weights=self.weights,
134+
mapping=self.token_mapping,
137135
)
138136

139137
def tokenize(self, sentences: Sequence[str], max_length: int | None = None) -> list[list[int]]:
@@ -490,11 +488,10 @@ def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:
490488
if not path.is_dir():
491489
raise ValueError(f"Path {path} is not a directory.")
492490

493-
embeddings, tokenizer, config, weights = load_local_model(path)
494-
token_mapping = cast(list[int], config.pop("token_mapping", None))
491+
embeddings, tokenizer, config, weights, mapping = load_local_model(path)
495492

496493
return StaticModel(
497-
vectors=embeddings, tokenizer=tokenizer, config=config, weights=weights, token_mapping=token_mapping
494+
vectors=embeddings, tokenizer=tokenizer, config=config, weights=weights, token_mapping=mapping
498495
)
499496

500497

@@ -517,7 +514,7 @@ def quantize_model(
517514
"""
518515
from model2vec.quantization import quantize_and_reduce_dim
519516

520-
token_mapping: list[int] | None
517+
token_mapping: np.ndarray | None
521518
weights: np.ndarray | None
522519
if vocabulary_quantization is not None:
523520
from model2vec.vocabulary_quantization import quantize_vocabulary
@@ -530,7 +527,7 @@ def quantize_model(
530527
)
531528
else:
532529
embeddings = model.embedding
533-
token_mapping = cast(list[int], model.token_mapping.tolist()) if model.token_mapping is not None else None
530+
token_mapping = model.token_mapping
534531
weights = model.weights
535532
if quantize_to is not None or dimensionality is not None:
536533
embeddings = quantize_and_reduce_dim(
@@ -568,20 +565,18 @@ def _loading_helper(
568565
if from_sentence_transformers and subfolder is not None:
569566
raise ValueError("Subfolder is not supported for sentence transformers models.")
570567

571-
embeddings, tokenizer, config, metadata, weights = load_pretrained(
568+
embeddings, tokenizer, config, metadata, weights, mapping = load_pretrained(
572569
folder_or_repo_path=path,
573570
token=token,
574571
from_sentence_transformers=from_sentence_transformers,
575572
subfolder=subfolder,
576573
)
577574

578-
token_mapping = config.pop("token_mapping", None)
579-
580575
model = cls(
581576
vectors=embeddings,
582577
tokenizer=tokenizer,
583578
weights=weights,
584-
token_mapping=token_mapping,
579+
token_mapping=mapping,
585580
config=config,
586581
normalize=normalize,
587582
base_model_name=metadata.get("base_model"),

model2vec/train/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def to_static_model(self) -> StaticModel:
150150
"""Convert the model to a static model."""
151151
emb = self.embeddings.weight.detach().cpu().numpy()
152152
w = torch.sigmoid(self.w).detach().cpu().numpy()
153-
token_mapping = self.token_mapping.tolist()
153+
token_mapping = self.token_mapping.numpy()
154154

155155
return StaticModel(
156156
vectors=emb, weights=w, tokenizer=self.tokenizer, normalize=True, token_mapping=token_mapping

model2vec/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,9 @@ def setup_logging() -> None:
104104
)
105105

106106

107-
def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str], np.ndarray | None]:
107+
def load_local_model(
108+
folder: Path,
109+
) -> tuple[np.ndarray, Tokenizer, dict[str, str], np.ndarray | None, np.ndarray | None]:
108110
"""Load a local model."""
109111
embeddings_path = folder / "model.safetensors"
110112
tokenizer_path = folder / "tokenizer.json"
@@ -117,6 +119,10 @@ def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str
117119
except Exception:
118120
# Catch the generic Exception because safetensors does not export its own error types.
119121
weights = None
122+
try:
123+
mapping = opened_tensor_file.get_tensor("mapping")
124+
except Exception:
125+
mapping = None
120126

121127
if config_path.exists():
122128
config = json.load(open(config_path))
@@ -125,4 +131,4 @@ def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str
125131

126132
tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
127133

128-
return embeddings, tokenizer, config, weights
134+
return embeddings, tokenizer, config, weights, mapping

model2vec/vocabulary_quantization.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
def quantize_vocabulary(
2222
n_clusters: int, weights: np.ndarray | None, embeddings: np.ndarray
23-
) -> tuple[np.ndarray, list[int], np.ndarray]:
23+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
2424
"""Quantize the vocabulary of embeddings using KMeans clustering."""
2525
logger.info(f"Quantizing vocabulary to {n_clusters} clusters.")
2626
# If the model does not have weights, we assume the norm to be informative.
@@ -38,8 +38,7 @@ def quantize_vocabulary(
3838
# Fit KMeans to the embeddings
3939
kmeans.fit(cast_embeddings)
4040
# Create a mapping from the original token index to the cluster index
41-
# Make sure to convert to list, otherwise we get np.int32 which is not jsonable.
42-
token_mapping = cast(list[int], kmeans.predict(cast_embeddings).tolist())
41+
token_mapping = kmeans.predict(cast_embeddings)
4342
# The cluster centers are the new embeddings.
4443
# Convert them back to the original dtype
4544
embeddings = kmeans.cluster_centers_.astype(orig_dtype)

tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_local_load(mock_tokenizer: Tokenizer, config: dict[str, Any], expected:
9898
mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json"))
9999
if config is not None:
100100
json.dump(config, open(tempdir_path / "config.json", "w"))
101-
arr, tokenizer, config, weights = load_local_model(tempdir_path)
101+
arr, tokenizer, config, weights, _ = load_local_model(tempdir_path)
102102
assert config == expected
103103
assert tokenizer.to_str() == mock_tokenizer.to_str()
104104
assert arr.shape == x.shape

0 commit comments

Comments
 (0)