Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions gittensor/validator/pat_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"""

import json
import logging
import os
import tempfile
import threading
Expand All @@ -18,6 +19,7 @@
PATS_FILE = Path(__file__).resolve().parents[2] / 'data' / 'miner_pats.json'

_lock = threading.Lock()
_logger = logging.getLogger(__name__)


def ensure_pats_file() -> None:
Expand All @@ -34,9 +36,14 @@ def load_all_pats() -> list[dict]:


def save_pat(uid: int, hotkey: str, pat: str, github_id: str) -> None:
"""Upsert a PAT entry by UID. Creates the file if needed."""
"""Upsert a PAT entry by UID. Creates the file if needed.

Raises json.JSONDecodeError / OSError if PATS_FILE exists but is unreadable;
we refuse to overwrite a corrupt file so a partial-write or on-disk corruption
cannot permanently destroy stored PATs.
"""
with _lock:
entries = _read_file()
entries = _read_file(raise_on_corrupt=True)

entry = {
'uid': uid,
Expand Down Expand Up @@ -66,23 +73,38 @@ def get_pat_by_uid(uid: int) -> Optional[dict]:


def remove_pat(uid: int) -> bool:
    """Delete the stored PAT entry for *uid*.

    Returns True when an entry existed and was removed, False otherwise.

    Raises json.JSONDecodeError / OSError if PATS_FILE exists but cannot be
    parsed (same refuse-to-overwrite invariant as save_pat): we never clobber
    an unreadable file with a rewritten entry list.
    """
    with _lock:
        # Surface corruption instead of treating the file as empty — writing
        # a filtered list over a corrupt file would destroy the stored PATs.
        current = _read_file(raise_on_corrupt=True)
        remaining = [entry for entry in current if entry.get('uid') != uid]
        removed = len(remaining) != len(current)
        if removed:
            _write_file(remaining)
        return removed


def _read_file(*, raise_on_corrupt: bool = False) -> list[dict]:
    """Load and parse PATS_FILE; the caller must already hold _lock.

    Read-only callers (load_all_pats, get_pat_by_uid) keep the default and
    fall back to an empty list with a warning, so a single corrupt file does
    not crash a validator scoring round. Write paths (save_pat, remove_pat)
    pass raise_on_corrupt=True so the error is surfaced rather than an
    unreadable file being overwritten with a fresh entry list.
    """
    if not PATS_FILE.exists():
        # Missing file is the normal first-run state, not corruption.
        return []
    try:
        return json.loads(PATS_FILE.read_text())
    except (OSError, json.JSONDecodeError) as err:
        if raise_on_corrupt:
            _logger.error('PATS_FILE %s is unreadable; refusing to overwrite: %s', PATS_FILE, err)
            raise
        _logger.warning('PATS_FILE %s is unreadable; treating as empty for read path: %s', PATS_FILE, err)
        return []


Expand Down
27 changes: 27 additions & 0 deletions tests/validator/test_pat_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,33 @@ def test_remove_preserves_others(self):
assert entries[0]['uid'] == 2


class TestCorruptFileRefusesToOverwrite:
    """Regression tests for #781: corrupt PATS_FILE must not be silently overwritten.

    Write paths (save_pat, remove_pat) must raise on a corrupt file and leave
    its bytes untouched; read paths keep the historical degrade-to-empty
    behavior.
    """

    def test_save_pat_raises_when_file_is_corrupt(self, use_tmp_pats_file):
        # Invalid JSON on disk: the write path must surface the parse error.
        use_tmp_pats_file.write_text('{not valid json')
        with pytest.raises(json.JSONDecodeError):
            pat_storage.save_pat(1, 'hotkey_1', 'ghp_abc', 'user_1')

    def test_save_pat_does_not_overwrite_corrupt_file_contents(self, use_tmp_pats_file):
        """A refused save must leave the corrupt bytes exactly as they were."""
        preserved_bytes = '{"foo'
        use_tmp_pats_file.write_text(preserved_bytes)
        with pytest.raises(json.JSONDecodeError):
            pat_storage.save_pat(1, 'hotkey_1', 'ghp_abc', 'user_1')
        assert use_tmp_pats_file.read_text() == preserved_bytes

    def test_remove_pat_raises_when_file_is_corrupt(self, use_tmp_pats_file):
        # Removal is also a write path and shares the refuse-to-overwrite rule.
        use_tmp_pats_file.write_text('garbage')
        with pytest.raises(json.JSONDecodeError):
            pat_storage.remove_pat(1)

    def test_load_all_pats_still_returns_empty_for_corrupt_file(self, use_tmp_pats_file):
        """Read path keeps the historical graceful-degradation contract."""
        use_tmp_pats_file.write_text('garbage')
        assert pat_storage.load_all_pats() == []


class TestConcurrency:
def test_concurrent_writes(self):
"""Multiple threads writing simultaneously should not corrupt the file."""
Expand Down
Loading