Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 61 additions & 22 deletions gittensor/utils/github_api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import fnmatch
import os
import time
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from math import ceil
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional

from gittensor.utils.utils import parse_repo_name

Expand Down Expand Up @@ -174,7 +175,51 @@ def make_headers(token: str) -> Dict[str, str]:

def make_graphql_headers(token: str) -> Dict[str, str]:
    """Build GitHub GraphQL headers for a PAT."""
    headers: Dict[str, str] = {'Authorization': f'Bearer {token}'}
    # GraphQL endpoint speaks JSON in both directions.
    headers['Content-Type'] = 'application/json'
    headers['Accept'] = 'application/json'
    return headers


def make_anonymous_headers() -> Dict[str, str]:
    """Build GitHub HTTP headers for unauthenticated calls."""
    headers: Dict[str, str] = {}
    # REST v3 media type; GitHub also requires a User-Agent on every request.
    headers['Accept'] = 'application/vnd.github.v3+json'
    headers['User-Agent'] = 'gittensor-cli'
    return headers


_session_cache: Optional[Dict[str, requests.Session]] = None


@contextmanager
def session_scope() -> Iterator[None]:
    """Share one requests.Session per PAT within this scope; close all on exit.

    On entry, installs the module-level cache so get_session() reuses sessions
    keyed by token; on exit, the cache is always detached and every cached
    session is closed. Not re-entrant.

    Raises:
        RuntimeError: if a scope is already active (nested use).
    """
    global _session_cache
    if _session_cache is not None:
        raise RuntimeError('session_scope is not re-entrant')
    _session_cache = {}
    try:
        yield
    finally:
        # Detach the cache before closing so get_session() allocates fresh
        # sessions again even if a close() below raises.
        cache, _session_cache = _session_cache, None
        # Bug fix: the original loop stopped at the first close() that raised,
        # leaking every remaining session. Close them all, then re-raise the
        # first failure so callers still observe the error.
        first_error: Optional[BaseException] = None
        for session in cache.values():
            try:
                session.close()
            except Exception as exc:
                if first_error is None:
                    first_error = exc
        if first_error is not None:
            raise first_error


def _build_session(token: str) -> requests.Session:
    """Allocate a requests.Session pre-loaded with headers for *token*.

    An empty token yields anonymous headers; otherwise PAT headers are used.
    """
    if token:
        headers = make_headers(token)
    else:
        headers = make_anonymous_headers()
    session = requests.Session()
    session.headers.update(headers)
    return session


def get_session(token: str) -> requests.Session:
    """Return a requests.Session for the given PAT.

    Inside an active session_scope() the session is cached per token
    (empty string keys the anonymous session) and reused; outside any
    scope a fresh session is allocated on every call.
    """
    cache = _session_cache
    if cache is None:
        # No scope active: caller owns (and must close) the fresh session.
        return _build_session(token)
    key = token or ''
    try:
        return cache[key]
    except KeyError:
        session = _build_session(token)
        cache[key] = session
        return session


def get_github_id(token: str) -> Optional[str]:
Expand All @@ -189,12 +234,12 @@ def get_github_id(token: str) -> Optional[str]:
if not token:
return None

headers = make_headers(token)
session = get_session(token)

# Retry logic for timeout issues
for attempt in range(6):
try:
response = requests.get(f'{BASE_GITHUB_API_URL}/user', headers=headers, timeout=GITHUB_HTTP_TIMEOUT_SECONDS)
response = session.get(f'{BASE_GITHUB_API_URL}/user', timeout=GITHUB_HTTP_TIMEOUT_SECONDS)
if response.status_code == 200:
try:
user_data: Dict[str, Any] = response.json()
Expand Down Expand Up @@ -235,14 +280,13 @@ def get_merge_base_sha(repository: str, base_sha: str, head_sha: str, token: str
Returns:
Merge-base commit SHA, or None if the request fails
"""
headers = make_headers(token)
session = get_session(token)
max_attempts = 3

for attempt in range(max_attempts):
try:
response = requests.get(
response = session.get(
f'{BASE_GITHUB_API_URL}/repos/{repository}/compare/{base_sha}...{head_sha}',
headers=headers,
timeout=15,
)

Expand Down Expand Up @@ -293,7 +337,7 @@ def get_pull_request_file_changes(repository: str, pr_number: int, token: str) -
"""
max_attempts = 3
per_page = 100
headers = make_headers(token)
session = get_session(token)

all_file_diffs: list = []
page = 1
Expand All @@ -302,9 +346,8 @@ def get_pull_request_file_changes(repository: str, pr_number: int, token: str) -

while attempt < max_attempts:
try:
response = requests.get(
response = session.get(
f'{BASE_GITHUB_API_URL}/repos/{repository}/pulls/{pr_number}/files',
headers=headers,
params={'per_page': per_page, 'page': page},
timeout=15,
)
Expand Down Expand Up @@ -485,20 +528,15 @@ def _search_issue_referencing_prs_rest(
if issue_number < 1:
return []

if token:
headers = make_headers(token)
else:
headers = {'Accept': 'application/vnd.github.v3+json'}
headers.setdefault('User-Agent', 'gittensor-cli')
session = get_session(token or '')

state_clause = f' state:{state}' if state != 'all' else ''
max_attempts = 3
for attempt in range(max_attempts):
try:
resp = requests.get(
resp = session.get(
f'{BASE_GITHUB_API_URL}/search/issues',
params={'q': f'repo:{repo} type:pr{state_clause} {issue_number} in:title,body', 'per_page': '50'},
headers=headers,
timeout=10,
)
resp.raise_for_status()
Expand Down Expand Up @@ -593,11 +631,12 @@ def execute_graphql_query(
Returns:
Parsed JSON response data, or None if all attempts failed
"""
session = get_session(token)
headers = make_graphql_headers(token)

for attempt in range(max_attempts):
try:
response = requests.post(
response = session.post(
f'{BASE_GITHUB_API_URL}/graphql',
headers=headers,
json={'query': query, 'variables': variables},
Expand Down Expand Up @@ -670,6 +709,7 @@ def get_github_graphql_query(
"""

max_attempts = 8
session = get_session(token)
headers = make_graphql_headers(token)
limit = page_size if page_size is not None else min(100, max_prs - merged_pr_count)

Expand All @@ -681,7 +721,7 @@ def get_github_graphql_query(
'maxChangesRequestedReviews': _MAX_CHANGES_REQUESTED_REVIEWS,
}
try:
response = requests.post(
response = session.post(
f'{BASE_GITHUB_API_URL}/graphql',
headers=headers,
json={'query': QUERY, 'variables': variables},
Expand Down Expand Up @@ -1068,12 +1108,11 @@ def check_github_issue_closed(repo: str, issue_number: int, token: str) -> Optio
Returns:
Dict with 'is_closed', 'solver_github_id', 'pr_number', 'solver_lookup_failed' or None on error
"""
headers = make_headers(token)
session = get_session(token)

try:
response = requests.get(
response = session.get(
f'{BASE_GITHUB_API_URL}/repos/{repo}/issues/{issue_number}',
headers=headers,
timeout=15,
)

Expand Down
9 changes: 9 additions & 0 deletions gittensor/utils/mirror/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,15 @@ def __init__(
self.max_attempts = max_attempts
self.session = session or requests.Session()

    def close(self) -> None:
        """Close the underlying requests session, releasing pooled connections."""
        self.session.close()

    def __enter__(self) -> 'MirrorClient':
        """Enable `with MirrorClient(...) as client:` usage; returns self."""
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Close the session on scope exit; returning None lets exceptions propagate."""
        self.close()

def get_miner_pulls(
self,
github_id: str,
Expand Down
13 changes: 8 additions & 5 deletions gittensor/validator/oss_contributions/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import numpy as np

from gittensor.classes import MinerEvaluation
from gittensor.utils.github_api_tools import load_miners_prs
from gittensor.utils.github_api_tools import load_miners_prs, session_scope
from gittensor.utils.mirror.client import MirrorClient
from gittensor.validator import pat_storage
from gittensor.validator.oss_contributions.inspections import (
detect_and_penalize_miners_sharing_github,
Expand Down Expand Up @@ -73,17 +74,19 @@ async def evaluate_miners_pull_requests(
}

if legacy_repos:
load_miners_prs(miner_eval, legacy_repos)
score_miner_prs(miner_eval, legacy_repos, programming_languages, token_config)
with session_scope():
load_miners_prs(miner_eval, legacy_repos)
score_miner_prs(miner_eval, legacy_repos, programming_languages, token_config)

if mirror_repos:
mirror_eval = MirrorMinerEvaluation(
uid=miner_eval.uid,
hotkey=miner_eval.hotkey,
github_id=miner_eval.github_id,
)
load_mirror_miner_prs(mirror_eval, mirror_repos)
score_mirror_miner_prs(mirror_eval, mirror_repos, programming_languages, token_config)
with MirrorClient() as mirror_client:
load_mirror_miner_prs(mirror_eval, mirror_repos, client=mirror_client)
score_mirror_miner_prs(mirror_eval, mirror_repos, programming_languages, token_config, client=mirror_client)
combine(miner_eval, mirror_eval)

# Clear PAT after scoring to avoid storing sensitive data in memory
Expand Down
48 changes: 48 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# The MIT License (MIT)
# Copyright © 2025 Entrius

"""
Root pytest configuration shared across all test directories.
"""

import pytest
import requests


class _ForwardingSession:
"""Session-like proxy delegating .get/.post to requests.get/post at call time.
Preserves compatibility with tests that @patch requests.get/post while production uses requests.Session via get_session().
"""

def __init__(self):
self.headers = {}

def get(self, *args, **kwargs):
if 'headers' not in kwargs:
kwargs['headers'] = dict(self.headers)
return requests.get(*args, **kwargs)

def post(self, *args, **kwargs):
if 'headers' not in kwargs:
kwargs['headers'] = dict(self.headers)
return requests.post(*args, **kwargs)


@pytest.fixture(autouse=True)
def _forward_github_sessions(monkeypatch):
    """Replace get_session with a forwarding proxy so @patch(requests.get/post) still works."""
    try:
        from gittensor.utils import github_api_tools
    except ImportError:
        # Module not importable in this environment: nothing to patch.
        yield
        return

    def _forwarding_get_session(token):
        proxy = _ForwardingSession()
        if token:
            proxy.headers.update(github_api_tools.make_headers(token))
        else:
            proxy.headers.update(github_api_tools.make_anonymous_headers())
        return proxy

    monkeypatch.setattr(github_api_tools, 'get_session', _forwarding_get_session)
    yield
80 changes: 80 additions & 0 deletions tests/utils/test_github_api_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1586,5 +1586,85 @@ def test_falls_back_to_base_ref_oid_when_merge_base_fails(self, mock_merge_base,
assert call_args[0][2] == 'base_branch_tip_sha', 'Should fall back to base_ref_oid'


# ============================================================================
# session_scope tests
# ============================================================================


# Bind to the unpatched implementations at import time so the autouse forwarding
# fixture in tests/conftest.py (which swaps out github_api_tools.get_session) does
# not interfere with these tests of the real caching behavior.
_real_session_scope = github_api_tools.session_scope
_real_get_session = github_api_tools.get_session


class TestSessionScope:
    """Direct tests for session_scope cache lifecycle."""

    def test_same_session_reused_for_same_token_in_scope(self):
        # Inside one scope, repeated lookups for the same PAT share a session.
        with _real_session_scope():
            s1 = _real_get_session('tokenA')
            s2 = _real_get_session('tokenA')
            assert s1 is s2

    def test_distinct_sessions_per_token(self):
        # Different PATs (and the anonymous '' key) must not share a session.
        with _real_session_scope():
            s_a = _real_get_session('tokenA')
            s_b = _real_get_session('tokenB')
            s_anon = _real_get_session('')
            assert s_a is not s_b
            assert s_a is not s_anon
            assert s_b is not s_anon

    def test_outside_scope_returns_fresh_session(self):
        # With no active scope, every call allocates a brand-new session.
        s1 = _real_get_session('tokenA')
        s2 = _real_get_session('tokenA')
        try:
            assert s1 is not s2
        finally:
            # Unmanaged sessions: the test owns them, so close explicitly.
            s1.close()
            s2.close()

    def test_re_entrance_raises(self):
        # Nested scopes are rejected rather than silently sharing one cache.
        with _real_session_scope():
            with pytest.raises(RuntimeError, match='not re-entrant'):
                with _real_session_scope():
                    pass

    def test_sessions_closed_on_exit(self, monkeypatch):
        # Track every session the scope builds so teardown can be verified.
        built: list = []

        def fake_build_session(token):
            session = Mock()
            session.headers = {}
            built.append(session)
            return session

        monkeypatch.setattr(github_api_tools, '_build_session', fake_build_session)

        with _real_session_scope():
            _real_get_session('tokenA')
            _real_get_session('tokenB')

        # One session per distinct token, each closed when the scope ended.
        assert len(built) == 2
        for session in built:
            session.close.assert_called_once()

    def test_cache_cleared_on_exit_even_when_close_raises(self, monkeypatch):
        def fake_build_session(token):
            session = Mock()
            session.headers = {}
            # Simulate a session whose close() fails during scope teardown.
            session.close.side_effect = RuntimeError('boom')
            return session

        monkeypatch.setattr(github_api_tools, '_build_session', fake_build_session)

        with pytest.raises(RuntimeError):
            with _real_session_scope():
                _real_get_session('tokenA')

        # The module-level cache must be reset even though close() raised.
        assert github_api_tools._session_cache is None


# Allow running this test module directly: `python test_github_api_tools.py`.
if __name__ == '__main__':
    pytest.main([__file__, '-v'])
Loading