Skip to content
Closed
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ data/transformed/uniprot_functional_microbes/uniprot_kgx.zip

data/transformed/ontologies/.Rapp.history
mediadive_cache.sqlite
mediadive_bulk_cache.sqlite
data/raw/mediadive/
data/raw/.keep
kg_microbe/transform_utils/uniprot_human/tmp/relevant_files.tsv
kg_microbe/transform_utils/uniprot/tmp/relevant_files.tsv
Expand Down
11 changes: 6 additions & 5 deletions kg_microbe/transform_utils/metatraits/metatraits.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def _get_ncbitaxon_adapter():
try:
print(f" Creating symlink to cached database: {local_db} -> {oak_cache}")
local_db.symlink_to(oak_cache)
print(f" Using cached NCBITaxon database from OAK")
print(" Using cached NCBITaxon database from OAK")
return get_adapter(f"sqlite:{local_db}")
except Exception as e:
print(f" Failed to create symlink ({e}), using remote adapter")
Expand All @@ -149,7 +149,7 @@ def _get_ncbitaxon_adapter():
try:
local_db.symlink_to(oak_cache)
print(f" Created symlink for future use: {local_db}")
except Exception:
except Exception: # noqa: S110
pass # Symlink creation is optional, don't fail if it doesn't work

return adapter
Expand Down Expand Up @@ -200,7 +200,7 @@ def _process_file_worker(args: Tuple[Path, Path, Dict[str, Any], bool]) -> Dict[
if hasattr(transform._ncbi_adapter, 'engine') and transform._ncbi_adapter.engine is not None:
transform._ncbi_adapter.engine.dispose()
transform._ncbi_adapter = None
except Exception:
except Exception: # noqa: S110
pass # Ignore cleanup errors


Expand Down Expand Up @@ -1421,7 +1421,8 @@ def run(
self._run_parallel(input_files, show_status, self.num_workers)
elif use_mp and len(input_files) == 1:
# Single file: split into chunks for parallel processing
print(f" Using parallel chunked processing (splitting 1 file across {self.num_workers or 'auto'} workers)")
workers_desc = self.num_workers or 'auto'
print(f" Using parallel chunked processing (splitting 1 file across {workers_desc} workers)")
self._run_parallel_chunked(input_files[0], show_status, self.num_workers)
else:
# No multiprocessing: sequential
Expand All @@ -1433,5 +1434,5 @@ def run(
if hasattr(self._ncbi_adapter, 'engine') and self._ncbi_adapter.engine is not None:
self._ncbi_adapter.engine.dispose()
self._ncbi_adapter = None
except Exception:
except Exception: # noqa: S110
pass # Ignore cleanup errors
164 changes: 139 additions & 25 deletions kg_microbe/utils/mediadive_bulk_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,23 @@

import json
import logging
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Dict, List
from typing import Dict, List, Optional

import requests
import requests_cache
from tqdm import tqdm

# Default to 5 workers — still a large speedup over sequential but polite to
# MediaDive, which is a small academic REST API at DSMZ.
DEFAULT_MAX_WORKERS: int = 5

# Descriptive User-Agent so the API operator can identify traffic source.
# NOTE(review): the host "github.com" looks mangled — confirm it should
# be "github.com" before shipping (runtime string, left unchanged here).
USER_AGENT: str = "kg-microbe (Knowledge-Graph-Hub; https://github.com/Knowledge-Graph-Hub/kg-microbe)"

# Set up logging for API warnings (written to file, not stdout)
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,32 +51,61 @@ def setup_cache(cache_name: str = "mediadive_bulk_cache"):
print(f"HTTP cache enabled: {cache_name}.sqlite")


def get_json_from_api(url: str, retry_count: int = 3, retry_delay: float = 2.0, verbose: bool = False) -> Dict:
def _make_session() -> requests.Session:
    """Build a fresh :class:`requests.Session` tagged with the kg-microbe User-Agent."""
    sess = requests.Session()
    # Identify our traffic to the API operator on every request made via this session.
    sess.headers["User-Agent"] = USER_AGENT
    return sess


def get_json_from_api(
url: str,
retry_count: int = 3,
retry_delay: float = 2.0,
verbose: bool = False,
session: Optional[requests.Session] = None,
) -> Dict:
"""
Get JSON data from MediaDive API with retry logic.

Respects Retry-After headers on 429 responses.

Args:
----
url: Full API URL to fetch
retry_count: Number of retries on failure
retry_delay: Delay in seconds between retries
retry_delay: Delay in seconds between retries (overridden by Retry-After on 429)
verbose: If True, log empty responses (useful for debugging)
session: Optional requests Session to reuse (uses module-level session if None)

Returns:
-------
Dictionary with API response data (empty dict on failure or empty response)

"""
requester = session or _make_session()
for attempt in range(retry_count):
Comment thread
turbomam marked this conversation as resolved.
Outdated
try:
r = requests.get(url, timeout=30)
r = requester.get(url, timeout=30)
r.raise_for_status()
data_json = r.json()
result = data_json.get(DATA_KEY, {})
# Distinguish empty API response from failure (for debugging)
if not result and verbose:
print(f" Empty response from API: {url}")
return result
except requests.exceptions.HTTPError as e:
if e.response is not None and e.response.status_code == 429:
wait = float(e.response.headers.get("Retry-After", retry_delay))
logger.debug(f"429 Too Many Requests — waiting {wait}s (URL: {url})")
time.sleep(wait)
continue
Comment thread
turbomam marked this conversation as resolved.
Outdated
Comment thread
turbomam marked this conversation as resolved.
Outdated
if attempt < retry_count - 1:
logger.debug(f"Retry {attempt + 1}/{retry_count} after error: {e} (URL: {url})")
time.sleep(retry_delay)
else:
logger.warning(f"Request failed after {retry_count} attempts: {e} (URL: {url})")
return {}
except requests.exceptions.RequestException as e:
if attempt < retry_count - 1:
logger.debug(f"Retry {attempt + 1}/{retry_count} after error: {e} (URL: {url})")
Expand Down Expand Up @@ -99,55 +137,117 @@ def load_basic_media_list(basic_file: str) -> List[Dict]:
return media_list


def download_detailed_media(media_list: List[Dict]) -> Dict[str, Dict]:
def _fetch_medium_detail(
    medium: Dict,
    session: requests.Session,
    rate_limiter: threading.Semaphore,
    retry_count: int,
    retry_delay: float,
) -> tuple[str, dict]:
    """Fetch detailed recipe for a single medium. Returns (medium_id, data)."""
    mid = str(medium.get(ID_KEY))
    endpoint = f"{MEDIADIVE_REST_API_BASE_URL}{MEDIUM_ENDPOINT}{mid}"
    # Hold a rate-limiter slot for the duration of the HTTP round-trip.
    with rate_limiter:
        payload = get_json_from_api(
            endpoint, retry_count=retry_count, retry_delay=retry_delay, session=session
        )
    return mid, payload


def _fetch_medium_strains(
    medium: Dict,
    session: requests.Session,
    rate_limiter: threading.Semaphore,
    retry_count: int,
    retry_delay: float,
) -> tuple[str, dict]:
    """Fetch strain associations for a single medium. Returns (medium_id, data)."""
    mid = str(medium.get(ID_KEY))
    endpoint = f"{MEDIADIVE_REST_API_BASE_URL}{MEDIUM_STRAINS_ENDPOINT}{mid}"
    # Hold a rate-limiter slot for the duration of the HTTP round-trip.
    with rate_limiter:
        payload = get_json_from_api(
            endpoint, retry_count=retry_count, retry_delay=retry_delay, session=session
        )
    return mid, payload

Comment thread
turbomam marked this conversation as resolved.

def download_detailed_media(
    media_list: List[Dict],
    max_workers: int = DEFAULT_MAX_WORKERS,
    retry_count: int = 3,
    retry_delay: float = 2.0,
    requests_per_second: float = 10.0,
) -> Dict[str, Dict]:
    """
    Download detailed recipe information for all media.

    Args:
    ----
        media_list: List of basic media records
        max_workers: Number of parallel download threads
        retry_count: Number of retries on request failure
        retry_delay: Seconds between retries (overridden by Retry-After on 429)
        requests_per_second: Maximum sustained request rate (smooths bursts);
            <= 0 disables pacing

    Returns:
    -------
        Dictionary mapping medium_id -> detailed_recipe_data

    """
    print(f"\nDownloading detailed recipes for {len(media_list)} media...")

    detailed_data: Dict[str, Dict] = {}
    session = _make_session()
    # Passed through to the per-medium helper, which acquires it around each
    # HTTP call. (With the executor already capped at max_workers it rarely
    # blocks, but it keeps the helper's contract intact.)
    rate_limiter = threading.Semaphore(max_workers)

    # BUGFIX: requests_per_second was previously accepted and documented but
    # never enforced. Pace request *starts* under a lock so the sustained rate
    # never exceeds the requested value; in-flight parallelism is unchanged.
    pace_lock = threading.Lock()
    min_interval = 1.0 / requests_per_second if requests_per_second > 0 else 0.0
    next_start = [0.0]  # monotonic timestamp of the next permitted request start

    def fetch(medium: Dict) -> tuple[str, dict]:
        if min_interval:
            with pace_lock:
                now = time.monotonic()
                wait = next_start[0] - now
                next_start[0] = max(now, next_start[0]) + min_interval
            if wait > 0:
                time.sleep(wait)
        return _fetch_medium_detail(medium, session, rate_limiter, retry_count, retry_delay)

    # executor.map preserves input order and propagates worker exceptions here.
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for medium_id, data in tqdm(
            executor.map(fetch, media_list),
            total=len(media_list),
            desc="Downloading medium details",
        ):
            if data:  # empty dict signals a failed request or empty API response
                detailed_data[medium_id] = data

    print(f"Downloaded {len(detailed_data)} detailed medium recipes")
    return detailed_data


def download_medium_strains(media_list: List[Dict]) -> Dict[str, List]:
def download_medium_strains(
media_list: List[Dict],
max_workers: int = DEFAULT_MAX_WORKERS,
retry_count: int = 3,
retry_delay: float = 2.0,
requests_per_second: float = 10.0,
) -> Dict[str, List]:
Comment thread
turbomam marked this conversation as resolved.
"""
Download strain associations for all media.

Args:
----
media_list: List of basic media records
max_workers: Number of parallel download threads
retry_count: Number of retries on request failure
retry_delay: Seconds between retries (overridden by Retry-After on 429)
requests_per_second: Maximum sustained request rate (smooths bursts)

Returns:
-------
Dictionary mapping medium_id -> list_of_strain_data

"""
print(f"\nDownloading strain associations for {len(media_list)} media...")
strain_data = {}

for medium in tqdm(media_list, desc="Downloading medium-strain associations"):
medium_id = str(medium.get(ID_KEY))
url = MEDIADIVE_REST_API_BASE_URL + MEDIUM_STRAINS_ENDPOINT + medium_id
data = get_json_from_api(url)
if data:
strain_data[medium_id] = data
strain_data: Dict[str, List] = {}
session = _make_session()
rate_limiter = threading.Semaphore(max_workers)

def fetch(medium: Dict) -> tuple[str, dict]:
return _fetch_medium_strains(medium, session, rate_limiter, retry_count, retry_delay)

with ThreadPoolExecutor(max_workers=max_workers) as executor:
for medium_id, data in tqdm(
executor.map(fetch, media_list),
Comment thread
turbomam marked this conversation as resolved.
total=len(media_list),
desc="Downloading medium-strain associations",
):
if data:
strain_data[medium_id] = data

# Count total strain associations, handling different data types
total_strains = 0
Expand Down Expand Up @@ -235,7 +335,13 @@ def save_json_file(data: Dict, filepath: Path, description: str):
print(f"Saved {description} to {filepath} ({file_size_mb:.2f} MB)")


def download_mediadive_bulk(basic_file: str, output_dir: str):
def download_mediadive_bulk(
basic_file: str,
output_dir: str,
max_workers: int = DEFAULT_MAX_WORKERS,
retry_count: int = 3,
retry_delay: float = 2.0,
):
"""
Download all MediaDive data in bulk.

Expand All @@ -245,6 +351,9 @@ def download_mediadive_bulk(basic_file: str, output_dir: str):
----
basic_file: Path to mediadive.json (basic media list)
output_dir: Directory to save bulk data files
max_workers: Number of parallel download threads (default: 5, polite for small APIs)
retry_count: Number of retries on request failure
retry_delay: Seconds between retries (overridden by Retry-After on 429)

"""
output_path = Path(output_dir)
Expand All @@ -261,6 +370,7 @@ def download_mediadive_bulk(basic_file: str, output_dir: str):
logger.setLevel(logging.DEBUG)
logger.propagate = False # Prevent propagation to root logger and stdout
print(f"API warnings will be logged to: {log_file}")
print(f"Using {max_workers} parallel workers")

# Set up HTTP caching
setup_cache()
Expand All @@ -271,12 +381,16 @@ def download_mediadive_bulk(basic_file: str, output_dir: str):

# Step 2: Download detailed medium recipes
print("\n[2/5] Downloading detailed medium recipes...")
detailed_media = download_detailed_media(media_list)
detailed_media = download_detailed_media(
media_list, max_workers=max_workers, retry_count=retry_count, retry_delay=retry_delay
)
save_json_file(detailed_media, output_path / "media_detailed.json", "detailed media recipes")

# Step 3: Download medium-strain associations
print("\n[3/5] Downloading medium-strain associations...")
media_strains = download_medium_strains(media_list)
media_strains = download_medium_strains(
media_list, max_workers=max_workers, retry_count=retry_count, retry_delay=retry_delay
)
save_json_file(media_strains, output_path / "media_strains.json", "medium-strain associations")

# Step 4: Extract solutions from embedded structure
Expand Down
2 changes: 1 addition & 1 deletion kg_microbe/utils/robot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def remove_convert_to_json(path: str, ont_name: str, terms: Union[List, Path]):
if isinstance(terms, List):
terms_param = [
item
for sublist in zip(["--term"] * len(terms), terms)
for sublist in zip(["--term"] * len(terms), terms, strict=False)
for item in sublist # noqa
Comment thread
turbomam marked this conversation as resolved.
Outdated
]
call = [
Expand Down
Loading
Loading