Skip to content

Commit 88c875f

Browse files
committed
block cache sketch
1 parent 43174a1 commit 88c875f

6 files changed

Lines changed: 563 additions & 1 deletion

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ classifiers = [
3030
"Programming Language :: Python :: Implementation :: PyPy",
3131
]
3232
dependencies = [
33+
"cachetools>=7.0.0",
3334
"obspec",
3435
"obstore",
3536
]

src/obspec_utils/kyle/__init__.py

Whitespace-only changes.
Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
from __future__ import annotations
2+
3+
from dataclasses import dataclass, field
4+
from typing import TYPE_CHECKING, Protocol
5+
6+
from cachetools import LRUCache
7+
from obspec import GetRange, GetRangeAsync, GetRanges, GetRangesAsync
8+
9+
if TYPE_CHECKING:
10+
from collections.abc import Buffer, Sequence
11+
12+
13+
class GetRangeAndGetRanges(GetRange, GetRanges, Protocol):
    """Structural type for a backend that offers both ``GetRange`` and ``GetRanges``.

    The protocol adds no members of its own; it only combines the two
    parent protocols so a single annotation can require both.
    """
17+
18+
19+
class GetRangeAsyncAndGetRangesAsync(GetRangeAsync, GetRangesAsync, Protocol):
    """Structural type for a backend offering ``GetRangeAsync`` and ``GetRangesAsync``.

    The protocol adds no members of its own; it only combines the two
    parent protocols so a single annotation can require both.
    """
23+
24+
25+
@dataclass
class MemoryCache:
    """Block-aligned LRU memory cache for remote data.

    Data is kept in fixed-size blocks keyed by ``(path, block_index)``.
    A cached block shorter than ``block_size`` (including an empty one)
    marks the end of the file: readers stop there and never ask for
    blocks beyond it.
    """

    block_size: int = 4 * 1024 * 1024  # 4 MiB
    max_blocks: int = 128  # 512 MiB default

    # (path, block_index) -> block_data (may be smaller than block_size at EOF).
    # Excluded from repr so printing the cache does not dump every cached block.
    _blocks: LRUCache[tuple[str, int], bytes] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Validate the configuration and create the LRU block store.

        Raises:
            ValueError: if ``block_size`` or ``max_blocks`` is not positive.
        """
        if self.block_size <= 0:
            raise ValueError("block_size must be a positive integer")
        if self.max_blocks <= 0:
            raise ValueError("max_blocks must be a positive integer")
        self._blocks = LRUCache(maxsize=self.max_blocks)

    def _block_index(self, offset: int) -> int:
        """Which block contains this byte offset."""
        return offset // self.block_size

    def _block_start(self, block_idx: int) -> int:
        """Starting byte offset of a block."""
        return block_idx * self.block_size

    def get(self, path: str, start: int, end: int) -> bytes | list[tuple[int, int]]:
        """Get data from cache, or return missing ranges to fetch.

        Args:
            path: Key identifying the remote object.
            start: First byte offset (inclusive).
            end: Last byte offset (exclusive).

        Returns:
            bytes if fully cached, or list of (start, end) ranges that need
            fetching. Missing ranges are block-aligned, and adjacent missing
            blocks are coalesced into a single range. The returned bytes may
            be shorter than ``end - start`` when a cached EOF block is reached.

        Raises:
            ValueError: if ``start`` is negative or ``end`` is before ``start``.
        """
        if start < 0 or end < start:
            raise ValueError(f"Invalid byte range: start={start}, end={end}")
        if start == end:
            # An empty range needs no data, so never trigger a block fetch for it.
            return b""

        first_block = self._block_index(start)
        last_block = self._block_index(end - 1)  # -1 because end is exclusive

        missing_blocks: list[int] = []
        chunks: list[bytes] = []

        for block_idx in range(first_block, last_block + 1):
            key = (path, block_idx)
            if key not in self._blocks:
                missing_blocks.append(block_idx)
                continue

            block_data = self._blocks[key]
            if not missing_blocks:
                # Only worth slicing while the result can still be assembled.
                block_start = self._block_start(block_idx)
                # Slicing clamps to the block length, so the upper bound
                # needs no explicit min().
                chunks.append(block_data[max(0, start - block_start) : end - block_start])

            if len(block_data) < self.block_size:
                # Partial block == EOF marker: nothing exists past this point,
                # so later blocks are neither read nor reported as missing.
                break

        if missing_blocks:
            return self._coalesce_missing_blocks(missing_blocks)
        return b"".join(chunks)

    def _coalesce_missing_blocks(
        self, missing_blocks: list[int]
    ) -> list[tuple[int, int]]:
        """Coalesce consecutive missing blocks into byte ranges.

        Adjacent missing blocks are always coalesced. Non-adjacent missing blocks
        (with cached blocks in between) are kept as separate ranges to avoid
        re-fetching cached data.

        Args:
            missing_blocks: Block indices in ascending order.

        Returns:
            Block-aligned ``(start, end)`` byte ranges, ``end`` exclusive.
        """
        if not missing_blocks:
            return []

        ranges: list[tuple[int, int]] = []
        run_first = run_last = missing_blocks[0]

        for block_idx in missing_blocks[1:]:
            if block_idx - run_last == 1:
                # Consecutive block: extend the current run.
                run_last = block_idx
                continue
            # Gap (a cached block in between): close the run, start a new one.
            ranges.append(
                (self._block_start(run_first), self._block_start(run_last + 1))
            )
            run_first = run_last = block_idx

        # Close the final run.
        ranges.append((self._block_start(run_first), self._block_start(run_last + 1)))
        return ranges

    def store(self, path: str, fetch_start: int, data: Buffer) -> None:
        """Store fetched data as blocks.

        The last block may be smaller than block_size if we hit EOF. When
        ``data`` is empty the fetch started at or past EOF; an empty block is
        recorded at ``fetch_start`` so later reads see the EOF marker instead
        of reporting the block as missing again.

        Args:
            path: Key identifying the remote object.
            fetch_start: Byte offset the data was fetched from; must be
                non-negative and block-aligned.
            data: The fetched bytes.

        Raises:
            ValueError: if ``fetch_start`` is negative or not block-aligned.
        """
        # Explicit raise rather than assert: asserts vanish under ``python -O``
        # and a misaligned store would silently corrupt the block layout.
        if fetch_start < 0 or fetch_start % self.block_size != 0:
            raise ValueError("fetch_start must be block-aligned")

        data_bytes = bytes(data)
        block_idx = fetch_start // self.block_size

        if not data_bytes:
            self._blocks[(path, block_idx)] = b""
            return

        for offset in range(0, len(data_bytes), self.block_size):
            self._blocks[(path, block_idx)] = data_bytes[offset : offset + self.block_size]
            block_idx += 1
153+
154+
155+
@dataclass
class SyncBlockCache:
    """Synchronous block cache wrapping a GetRange/GetRanges backend.

    Reads are served from ``cache`` when possible. Missing block-aligned
    ranges are fetched from ``backend`` and stored in the cache. If a
    request still cannot be served from the cache after fetching (for
    example because blocks were evicted, or the backend returned less data
    than requested), it is read directly from the backend instead.
    """

    backend: GetRangeAndGetRanges
    cache: MemoryCache = field(default_factory=MemoryCache)

    def get_range(
        self,
        path: str,
        *,
        start: int,
        end: int | None = None,
        length: int | None = None,
    ) -> bytes:
        """Return the bytes of ``path`` in ``[start, end)``.

        Args:
            path: Key identifying the remote object.
            start: First byte offset (inclusive).
            end: Last byte offset (exclusive). Takes precedence over ``length``.
            length: Number of bytes to read, used when ``end`` is not given.

        Raises:
            ValueError: if neither ``end`` nor ``length`` is provided.
        """
        if end is None:
            if length is None:
                raise ValueError("Either end or length must be provided")
            end = start + length

        result = self.cache.get(path, start, end)
        if isinstance(result, list):
            # result is list of missing ranges - fetch them
            self._fetch_missing(path, result)
            result = self.cache.get(path, start, end)
            if isinstance(result, list):
                # Still not servable from cache: bypass it rather than fail.
                return bytes(self.backend.get_range(path, start=start, end=end))

        return result

    def get_ranges(
        self,
        path: str,
        *,
        starts: Sequence[int],
        ends: Sequence[int] | None = None,
        lengths: Sequence[int] | None = None,
    ) -> Sequence[bytes]:
        """Return the bytes stored at the specified location in the given byte ranges.

        Args:
            path: Key identifying the remote object.
            starts: First byte offset (inclusive) of each range.
            ends: Exclusive end offset of each range. Takes precedence over
                ``lengths``.
            lengths: Length of each range, used when ``ends`` is not given.

        Returns:
            One ``bytes`` object per requested range, in request order.

        Raises:
            ValueError: if neither ``ends`` nor ``lengths`` is provided.
        """
        if ends is None:
            if lengths is None:
                raise ValueError("Either ends or lengths must be provided")
            ends = [s + length for s, length in zip(starts, lengths)]

        requests = list(zip(starts, ends))

        # Collect all missing ranges across all requests
        all_missing: list[tuple[int, int]] = []
        for start, end in requests:
            cached = self.cache.get(path, start, end)
            if isinstance(cached, list):
                all_missing.extend(cached)

        # Fetch all missing ranges in one batch. Requests touching the same
        # blocks report the same ranges, so merge them to avoid duplicate reads.
        if all_missing:
            self._fetch_missing(path, self._merge_ranges(all_missing))

        # Collect results, remembering requests the cache still cannot serve.
        results: list[bytes] = []
        unserved: list[int] = []
        for position, (start, end) in enumerate(requests):
            cached = self.cache.get(path, start, end)
            if isinstance(cached, list):
                unserved.append(position)
                results.append(b"")  # placeholder, replaced below
            else:
                results.append(cached)

        if unserved:
            # Bypass the cache for whatever it could not hold.
            direct = self.backend.get_ranges(
                path,
                starts=[requests[i][0] for i in unserved],
                ends=[requests[i][1] for i in unserved],
            )
            for position, data in zip(unserved, direct):
                results[position] = bytes(data)

        return results

    @staticmethod
    def _merge_ranges(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]:
        """Sort ranges and merge those that overlap or touch."""
        merged: list[tuple[int, int]] = []
        for start, end in sorted(ranges):
            if merged and start <= merged[-1][1]:
                prev_start, prev_end = merged[-1]
                merged[-1] = (prev_start, max(prev_end, end))
            else:
                merged.append((start, end))
        return merged

    def _fetch_missing(self, path: str, ranges: list[tuple[int, int]]) -> None:
        """Fetch missing ranges from backend and store in cache.

        A single range uses ``backend.get_range``; several ranges use one
        ``backend.get_ranges`` call. Each range start is passed to the cache
        as the block-aligned store offset.
        """
        if not ranges:
            return

        if len(ranges) == 1:
            start, end = ranges[0]
            data = self.backend.get_range(path, start=start, end=end)
            self.cache.store(path, start, data)
        else:
            buffers: Sequence[Buffer] = self.backend.get_ranges(
                path,
                starts=[r[0] for r in ranges],
                ends=[r[1] for r in ranges],
            )
            for (range_start, _), data in zip(ranges, buffers):
                self.cache.store(path, range_start, data)
233+
234+
235+
@dataclass
class AsyncBlockCache(GetRangeAsync, GetRangesAsync):
    """Async block cache wrapping a GetRangeAsync/GetRangesAsync backend.

    Reads are served from ``cache`` when possible. Missing block-aligned
    ranges are fetched from ``backend`` and stored in the cache. If a
    request still cannot be served from the cache after fetching (for
    example because blocks were evicted, or the backend returned less data
    than requested), it is read directly from the backend instead.
    """

    backend: GetRangeAsyncAndGetRangesAsync
    cache: MemoryCache = field(default_factory=MemoryCache)

    async def get_range_async(
        self,
        path: str,
        *,
        start: int,
        end: int | None = None,
        length: int | None = None,
    ) -> bytes:
        """Return the bytes of ``path`` in ``[start, end)``.

        Args:
            path: Key identifying the remote object.
            start: First byte offset (inclusive).
            end: Last byte offset (exclusive). Takes precedence over ``length``.
            length: Number of bytes to read, used when ``end`` is not given.

        Raises:
            ValueError: if neither ``end`` nor ``length`` is provided.
        """
        if end is None:
            if length is None:
                raise ValueError("Either end or length must be provided")
            end = start + length

        result = self.cache.get(path, start, end)
        if isinstance(result, list):
            # result is list of missing ranges - fetch them
            await self._fetch_missing(path, result)
            result = self.cache.get(path, start, end)
            if isinstance(result, list):
                # Still not servable from cache: bypass it rather than fail.
                return bytes(
                    await self.backend.get_range_async(path, start=start, end=end)
                )

        return result

    async def get_ranges_async(
        self,
        path: str,
        *,
        starts: Sequence[int],
        ends: Sequence[int] | None = None,
        lengths: Sequence[int] | None = None,
    ) -> Sequence[bytes]:
        """Return the bytes stored at the specified location in the given byte ranges.

        Args:
            path: Key identifying the remote object.
            starts: First byte offset (inclusive) of each range.
            ends: Exclusive end offset of each range. Takes precedence over
                ``lengths``.
            lengths: Length of each range, used when ``ends`` is not given.

        Returns:
            One ``bytes`` object per requested range, in request order.

        Raises:
            ValueError: if neither ``ends`` nor ``lengths`` is provided.
        """
        if ends is None:
            if lengths is None:
                raise ValueError("Either ends or lengths must be provided")
            ends = [s + length for s, length in zip(starts, lengths)]

        requests = list(zip(starts, ends))

        # Collect all missing ranges across all requests
        all_missing: list[tuple[int, int]] = []
        for start, end in requests:
            cached = self.cache.get(path, start, end)
            if isinstance(cached, list):
                all_missing.extend(cached)

        # Fetch all missing ranges in one batch. Requests touching the same
        # blocks report the same ranges, so merge them to avoid duplicate reads.
        if all_missing:
            await self._fetch_missing(path, self._merge_ranges(all_missing))

        # Collect results, remembering requests the cache still cannot serve.
        results: list[bytes] = []
        unserved: list[int] = []
        for position, (start, end) in enumerate(requests):
            cached = self.cache.get(path, start, end)
            if isinstance(cached, list):
                unserved.append(position)
                results.append(b"")  # placeholder, replaced below
            else:
                results.append(cached)

        if unserved:
            # Bypass the cache for whatever it could not hold.
            direct = await self.backend.get_ranges_async(
                path,
                starts=[requests[i][0] for i in unserved],
                ends=[requests[i][1] for i in unserved],
            )
            for position, data in zip(unserved, direct):
                results[position] = bytes(data)

        return results

    @staticmethod
    def _merge_ranges(ranges: list[tuple[int, int]]) -> list[tuple[int, int]]:
        """Sort ranges and merge those that overlap or touch."""
        merged: list[tuple[int, int]] = []
        for start, end in sorted(ranges):
            if merged and start <= merged[-1][1]:
                prev_start, prev_end = merged[-1]
                merged[-1] = (prev_start, max(prev_end, end))
            else:
                merged.append((start, end))
        return merged

    async def _fetch_missing(self, path: str, ranges: list[tuple[int, int]]) -> None:
        """Fetch missing ranges from backend and store in cache.

        A single range uses ``backend.get_range_async``; several ranges use
        one ``backend.get_ranges_async`` call. Each range start is passed to
        the cache as the block-aligned store offset.
        """
        if not ranges:
            return

        if len(ranges) == 1:
            start, end = ranges[0]
            data = await self.backend.get_range_async(path, start=start, end=end)
            self.cache.store(path, start, data)
        else:
            buffers: Sequence[Buffer] = await self.backend.get_ranges_async(
                path,
                starts=[r[0] for r in ranges],
                ends=[r[1] for r in ranges],
            )
            for (range_start, _), data in zip(ranges, buffers):
                self.cache.store(path, range_start, data)

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from pathlib import Path
44

55
import pytest
6-
import xarray as xr
76

87

98
def pytest_addoption(parser):
@@ -105,6 +104,8 @@ def minio_bucket(container):
105104
@pytest.fixture
106105
def local_netcdf4_file(tmp_path: Path) -> str:
107106
"""Create a NetCDF4 file with data in multiple groups."""
107+
import xarray as xr
108+
108109
filepath = tmp_path / "test.nc"
109110
ds1 = xr.DataArray([1, 2, 3], name="foo").to_dataset()
110111
ds1.to_netcdf(filepath)

tests/kyle/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)