Skip to content

Commit e666b4b

Browse files
committed
feat: faster loading if model already cached
1 parent 55b955a commit e666b4b

6 files changed

Lines changed: 58 additions & 133 deletions

File tree

model2vec/distill/inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def create_embeddings(
4646
:param pad_token_id: The pad token id. Used to pad sequences.
4747
:return: The output embeddings.
4848
"""
49-
model = model.to(device)
49+
model = model.to(device) # type: ignore
5050

5151
out_weights: np.ndarray
5252
intermediate_weights: list[np.ndarray] = []
@@ -98,7 +98,7 @@ def _encode_mean_using_model(model: PreTrainedModel, encodings: dict[str, torch.
9898
"""
9999
encodings = {k: v.to(model.device) for k, v in encodings.items()}
100100
encoded: BaseModelOutputWithPoolingAndCrossAttentions = model(**encodings)
101-
out: torch.Tensor = encoded.last_hidden_state.cpu()
101+
out: torch.Tensor = encoded.last_hidden_state.cpu() # type: ignore # typing is wrong.
102102
# NOTE: If the dtype is bfloat 16, we convert to float32,
103103
# because numpy does not suport bfloat16
104104
# See here: https://github.com/numpy/numpy/issues/19808

model2vec/hf_utils.py

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import numpy as np
1010
import safetensors
1111
from huggingface_hub import ModelCard, ModelCardData
12+
from huggingface_hub.constants import HF_HUB_CACHE
1213
from safetensors.numpy import save_file
1314
from tokenizers import Tokenizer
1415

@@ -99,6 +100,7 @@ def load_pretrained(
99100
subfolder: str | None = None,
100101
token: str | None = None,
101102
from_sentence_transformers: bool = False,
103+
skip_metadata: bool = False,
102104
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
103105
"""
104106
Loads a pretrained model from a folder.
@@ -109,6 +111,7 @@ def load_pretrained(
109111
:param subfolder: The subfolder to load from.
110112
:param token: The huggingface token to use.
111113
:param from_sentence_transformers: Whether to load the model from a sentence transformers model.
114+
:param skip_metadata: Whether to skip loading metadata. This is useful if you don't need the metadata.
112115
:raises: FileNotFoundError if the folder exists, but the file does not exist locally.
113116
:return: The embeddings, tokenizer, config, and metadata.
114117
@@ -122,7 +125,12 @@ def load_pretrained(
122125
tokenizer_file = "tokenizer.json"
123126
config_name = "config.json"
124127

125-
folder_or_repo_path = Path(folder_or_repo_path)
128+
if cached_folder := _get_latest_model_path(str(folder_or_repo_path)):
129+
logger.info(f"Found cached model at {cached_folder}, loading from cache.")
130+
folder_or_repo_path = cached_folder
131+
else:
132+
logger.info(f"No cached model found for {folder_or_repo_path}, loading from local or hub.")
133+
folder_or_repo_path = Path(folder_or_repo_path)
126134

127135
local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path
128136

@@ -139,9 +147,7 @@ def load_pretrained(
139147
if not tokenizer_path.exists():
140148
raise FileNotFoundError(f"Tokenizer file does not exist in {local_folder}")
141149

142-
# README is optional, so this is a bit finicky.
143150
readme_path = local_folder / "README.md"
144-
metadata = _get_metadata_from_readme(readme_path)
145151

146152
else:
147153
logger.info("Folder does not exist locally, attempting to use huggingface hub.")
@@ -150,18 +156,11 @@ def load_pretrained(
150156
folder_or_repo_path.as_posix(), model_file, token=token, subfolder=subfolder
151157
)
152158
)
153-
154-
try:
155-
readme_path = Path(
156-
huggingface_hub.hf_hub_download(
157-
folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
158-
)
159+
readme_path = Path(
160+
huggingface_hub.hf_hub_download(
161+
folder_or_repo_path.as_posix(), "README.md", token=token, subfolder=subfolder
159162
)
160-
metadata = _get_metadata_from_readme(Path(readme_path))
161-
except Exception as e:
162-
# NOTE: we don't want to raise an error here, since the README is optional.
163-
logger.info(f"No README found in the model folder: {e} No model card loaded.")
164-
metadata = {}
163+
)
165164

166165
config_path = Path(
167166
huggingface_hub.hf_hub_download(
@@ -175,10 +174,13 @@ def load_pretrained(
175174
)
176175

177176
opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
178-
if from_sentence_transformers:
179-
embeddings = opened_tensor_file.get_tensor("embedding.weight")
177+
embedding_key = "embedding.weight" if from_sentence_transformers else "embeddings"
178+
embeddings = opened_tensor_file.get_tensor(embedding_key)
179+
180+
if not skip_metadata and readme_path.exists():
181+
metadata = _get_metadata_from_readme(readme_path)
180182
else:
181-
embeddings = opened_tensor_file.get_tensor("embeddings")
183+
metadata = {}
182184

183185
tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
184186
config = json.load(open(config_path))
@@ -223,3 +225,28 @@ def push_folder_to_hub(
223225
huggingface_hub.upload_folder(repo_id=repo_id, folder_path=folder_path, token=token, path_in_repo=subfolder)
224226

225227
logger.info(f"Pushed model to {repo_id}")
228+
229+
230+
def _get_latest_model_path(model_id: str) -> Path | None:
231+
"""
232+
Gets the latest model path for a given identifier from the hugging face hub cache.
233+
234+
Returns None if there is no cached model. In this case, the model will be downloaded.
235+
"""
236+
# Make path object
237+
cache_dir = Path(HF_HUB_CACHE)
238+
# This is specific to how HF stores the files.
239+
normalized = model_id.replace("/", "--")
240+
repo_dir = cache_dir / f"models--{normalized}" / "snapshots"
241+
242+
if not repo_dir.exists():
243+
return None
244+
245+
# Find all directories.
246+
snapshots = [p for p in repo_dir.iterdir() if p.is_dir()]
247+
if not snapshots:
248+
return None
249+
250+
# Get the latest directory by modification time.
251+
latest_snapshot = max(snapshots, key=lambda p: p.stat().st_mtime)
252+
return latest_snapshot

model2vec/model.py

Lines changed: 9 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from tqdm import tqdm
1414

1515
from model2vec.quantization import DType, quantize_and_reduce_dim
16-
from model2vec.utils import ProgressParallel, load_local_model
16+
from model2vec.utils import ProgressParallel
1717

1818
PathLike = Union[Path, str]
1919

@@ -156,6 +156,7 @@ def from_pretrained(
156156
subfolder: str | None = None,
157157
quantize_to: str | DType | None = None,
158158
dimensionality: int | None = None,
159+
skip_metadata: bool = False,
159160
) -> StaticModel:
160161
"""
161162
Load a StaticModel from a local path or huggingface hub path.
@@ -171,6 +172,8 @@ def from_pretrained(
171172
:param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
172173
This is useful if you want to load a model with a lower dimensionality.
173174
Note that this only applies if you have trained your model using mrl or PCA.
175+
:param skip_metadata: Whether to skip loading metadata. This is useful if you don't need the metadata.
176+
Loading metadata can be slow for models with lots of results in the README.md
174177
:return: A StaticModel.
175178
"""
176179
from model2vec.hf_utils import load_pretrained
@@ -180,6 +183,7 @@ def from_pretrained(
180183
token=token,
181184
from_sentence_transformers=False,
182185
subfolder=subfolder,
186+
skip_metadata=skip_metadata,
183187
)
184188

185189
embeddings = quantize_and_reduce_dim(
@@ -205,6 +209,7 @@ def from_sentence_transformers(
205209
normalize: bool | None = None,
206210
quantize_to: str | DType | None = None,
207211
dimensionality: int | None = None,
212+
skip_metadata: bool = False,
208213
) -> StaticModel:
209214
"""
210215
Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
@@ -219,6 +224,8 @@ def from_sentence_transformers(
219224
:param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
220225
This is useful if you want to load a model with a lower dimensionality.
221226
Note that this only applies if you have trained your model using mrl or PCA.
227+
:param skip_metadata: Whether to skip loading metadata. This is useful if you don't need the metadata.
228+
Loading metadata can be slow for models with lots of results in the README.md
222229
:return: A StaticModel.
223230
"""
224231
from model2vec.hf_utils import load_pretrained
@@ -228,6 +235,7 @@ def from_sentence_transformers(
228235
token=token,
229236
from_sentence_transformers=True,
230237
subfolder=None,
238+
skip_metadata=skip_metadata,
231239
)
232240

233241
embeddings = quantize_and_reduce_dim(
@@ -447,28 +455,3 @@ def push_to_hub(
447455
with TemporaryDirectory() as temp_dir:
448456
self.save_pretrained(temp_dir, model_name=repo_id)
449457
push_folder_to_hub(Path(temp_dir), subfolder=subfolder, repo_id=repo_id, private=private, token=token)
450-
451-
@classmethod
452-
def load_local(cls: type[StaticModel], path: PathLike) -> StaticModel:
453-
"""
454-
Loads a model from a local path.
455-
456-
You should only use this code path if you are concerned with start-up time.
457-
Loading via the `from_pretrained` method is safer, and auto-downloads, but
458-
also means we import a whole bunch of huggingface code that we don't need.
459-
460-
Additionally, huggingface will check the most recent version of the model,
461-
which can be slow.
462-
463-
:param path: The path to load the model from. The path is a directory saved by the
464-
`save_pretrained` method.
465-
:return: A StaticModel
466-
:raises: ValueError if the path is not a directory.
467-
"""
468-
path = Path(path)
469-
if not path.is_dir():
470-
raise ValueError(f"Path {path} is not a directory.")
471-
472-
embeddings, tokenizer, config = load_local_model(path)
473-
474-
return StaticModel(embeddings, tokenizer, config)

model2vec/utils.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -102,27 +102,3 @@ def setup_logging() -> None:
102102
datefmt="%Y-%m-%d %H:%M:%S",
103103
handlers=[RichHandler(rich_tracebacks=True)],
104104
)
105-
106-
107-
def load_local_model(folder: Path) -> tuple[np.ndarray, Tokenizer, dict[str, str]]:
108-
"""Load a local model."""
109-
embeddings_path = folder / "model.safetensors"
110-
tokenizer_path = folder / "tokenizer.json"
111-
config_path = folder / "config.json"
112-
113-
opened_tensor_file = cast(SafeOpenProtocol, safetensors.safe_open(embeddings_path, framework="numpy"))
114-
embeddings = opened_tensor_file.get_tensor("embeddings")
115-
116-
if config_path.exists():
117-
config = json.load(open(config_path))
118-
else:
119-
config = {}
120-
121-
tokenizer: Tokenizer = Tokenizer.from_file(str(tokenizer_path))
122-
123-
if len(tokenizer.get_vocab()) != len(embeddings):
124-
logger.warning(
125-
f"Number of tokens does not match number of embeddings: `{len(tokenizer.get_vocab())}` vs `{len(embeddings)}`"
126-
)
127-
128-
return embeddings, tokenizer, config

tests/test_model.py

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ def test_encode_as_tokens_empty(
118118
encoded = model.encode_as_sequence("")
119119
assert np.array_equal(encoded, np.zeros(shape=(0, 2), dtype=model.embedding.dtype))
120120

121-
encoded = model.encode_as_sequence(["", ""])
121+
encoded_list = model.encode_as_sequence(["", ""])
122122
out = [np.zeros(shape=(0, 2), dtype=model.embedding.dtype) for _ in range(2)]
123-
assert [np.array_equal(x, y) for x, y in zip(encoded, out)]
123+
assert [np.array_equal(x, y) for x, y in zip(encoded_list, out)]
124124

125125

126126
def test_encode_empty_sentence(
@@ -273,23 +273,3 @@ def test_dim(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: d
273273
model = StaticModel(mock_vectors, mock_tokenizer, mock_config)
274274
assert model.dim == 2
275275
assert model.dim == model.embedding.shape[1]
276-
277-
278-
def test_local_load_from_model(mock_tokenizer: Tokenizer) -> None:
279-
"""Test local load from a model."""
280-
x = np.ones((mock_tokenizer.get_vocab_size(), 2))
281-
with TemporaryDirectory() as tempdir:
282-
tempdir_path = Path(tempdir)
283-
safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors")
284-
mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json"))
285-
286-
model = StaticModel.load_local(tempdir_path)
287-
assert model.embedding.shape == x.shape
288-
assert model.tokenizer.to_str() == mock_tokenizer.to_str()
289-
assert model.config == {"normalize": False}
290-
291-
292-
def test_local_load_from_model_no_folder() -> None:
293-
"""Test local load from a model with no folder."""
294-
with pytest.raises(ValueError):
295-
StaticModel.load_local("woahbuddy_relax_this_is_just_a_test")

tests/test_utils.py

Lines changed: 1 addition & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from model2vec.distill.utils import select_optimal_device
1616
from model2vec.hf_utils import _get_metadata_from_readme
17-
from model2vec.utils import get_package_extras, importable, load_local_model
17+
from model2vec.utils import get_package_extras, importable
1818

1919

2020
def test__get_metadata_from_readme_not_exists() -> None:
@@ -78,44 +78,3 @@ def test_get_package_extras() -> None:
7878
def test_get_package_extras_empty() -> None:
7979
"""Test package extras with an empty package."""
8080
assert not list(get_package_extras("tqdm", ""))
81-
82-
83-
@pytest.mark.parametrize(
84-
"config, expected",
85-
[
86-
({"dog": "cat"}, {"dog": "cat"}),
87-
({}, {}),
88-
(None, {}),
89-
],
90-
)
91-
def test_local_load(mock_tokenizer: Tokenizer, config: dict[str, Any], expected: dict[str, Any]) -> None:
92-
"""Test local loading."""
93-
x = np.ones((mock_tokenizer.get_vocab_size(), 2))
94-
95-
with TemporaryDirectory() as tempdir:
96-
tempdir_path = Path(tempdir)
97-
safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors")
98-
mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json"))
99-
if config is not None:
100-
json.dump(config, open(tempdir_path / "config.json", "w"))
101-
arr, tokenizer, config = load_local_model(tempdir_path)
102-
assert config == expected
103-
assert tokenizer.to_str() == mock_tokenizer.to_str()
104-
assert arr.shape == x.shape
105-
106-
107-
def test_local_load_mismatch(mock_tokenizer: Tokenizer, caplog: pytest.LogCaptureFixture) -> None:
108-
"""Test local loading."""
109-
x = np.ones((10, 2))
110-
111-
with TemporaryDirectory() as tempdir:
112-
tempdir_path = Path(tempdir)
113-
safetensors.numpy.save_file({"embeddings": x}, Path(tempdir) / "model.safetensors")
114-
mock_tokenizer.save(str(Path(tempdir) / "tokenizer.json"))
115-
116-
load_local_model(tempdir_path)
117-
expected = (
118-
f"Number of tokens does not match number of embeddings: `{len(mock_tokenizer.get_vocab())}` vs `{len(x)}`"
119-
)
120-
assert len(caplog.records) == 1
121-
assert caplog.records[0].message == expected

0 commit comments

Comments
 (0)