Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
.gitmodules

.conda
.neptune
mlruns
wandb
.pytest_cache
.mypy_cache
.ruff_cache
Expand Down
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ results/
# Hydra logs
outputs/
multirun/
.neptune
mlruns/
wandb/

# AI
.aider**
32 changes: 24 additions & 8 deletions mava/configs/logger/logger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,36 @@ loggers:
console:
_target_: mava.utils.logger.ConsoleLogger
enabled: True
neptune:
_target_: mava.utils.logger.NeptuneLogger
mlflow:
_target_: mava.utils.logger.MLflowLogger
enabled: True

# If specified, the run with this run ID will be resumed.
# This is useful for resuming runs and logging from multiple processes.
# NOTE: it will overwrite the run unless you set the timestep correctly.
run_id: ~
experiment_name: DAIR-Research/mava
tags:
purpose: delete
detailed_logging: False # having mean/std/min/max can clutter mlflow so we make it optional
# Whether JSON file data should be uploaded to MLflow for downstream
# aggregation and plotting of data from multiple experiments. Note that when uploading JSON files,
# `json.path` must be unset to ensure that uploaded json files don't continue getting larger
# over time. Setting both will raise an error.
upload_json_data: False
wandb:
_target_: mava.utils.logger.WandBLogger
enabled: False

# If specified, the run with this run ID will be resumed.
# This is useful for resuming runs and logging from multiple processes.
# NOTE: it will overwrite the run unless you set the timestep correctly.
run_id: ~
project: Instadeep/mava-benchmark
tag: [delete]
group_tag: [delete]
detailed_logging: False # having mean/std/min/max can clutter neptune so we make it optional
architecture_name: ${arch.architecture_name} # this is required because async logging causes deadlocks in sebulba
# Whether JSON file data should be uploaded to Neptune for downstream
project: mava-benchmark
entity: ~
tags: [delete]
detailed_logging: False # having mean/std/min/max can clutter wandb so we make it optional
# Whether JSON file data should be uploaded to WandB for downstream
# aggregation and plotting of data from multiple experiments. Note that when uploading JSON files,
# `json.path` must be unset to ensure that uploaded json files don't continue getting larger
# over time. Setting both will raise an error.
Expand Down
163 changes: 127 additions & 36 deletions mava/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@

import hydra
import jax
import neptune
import mlflow
import numpy as np
import wandb
from colorama import Fore, Style
from etils.epath import Path
from jax import tree
from jax.typing import ArrayLike
from marl_eval.json_tools import JsonLogger as MarlEvalJsonLogger
from neptune.utils import stringify_unsupported
from omegaconf import DictConfig, OmegaConf
from pandas.io.json._normalize import _simple_json_normalize as flatten_dict
from rich.pretty import pprint
Expand Down Expand Up @@ -209,49 +209,47 @@ def stop(self) -> None:
logger.stop()


class NeptuneLogger(BaseLogger):
class MLflowLogger(BaseLogger):
def __init__(
self,
base_exp_path: PathLike,
unique_token: str,
system_name: str,
project: str,
tag: list[str],
group_tag: list[str],
experiment_name: str,
tags: dict[str, str],
detailed_logging: bool,
architecture_name: str,
upload_json_data: bool,
run_id: str | None = None,
) -> None:
"""
Initialize neptune.ai logger for experiment tracking.
Initialize MLflow logger for experiment tracking.

Args:
base_exp_path: Base path where all logs are stored.
unique_token: Unique identifier string for this run.
system_name: Name of the system/algorithm being logged.
project: neptune.ai project name.
tag: List of tags for the neptune.ai experiment.
group_tag: List of group tags - useful for keeping track of a group of experiments.
experiment_name: MLflow experiment name.
tags: Dictionary of tags for the MLflow run.
detailed_logging: Whether to log detailed metrics (incl. std/min/max).
architecture_name: Name of the architecture [anakin | sebulba].
upload_json_data: Whether to upload JSON data to neptune.ai.
upload_json_data: Whether to upload JSON data to MLflow.
run_id: ID of the run you wish to resume - None if you don't want to resume the run.
Note this will overwrite the run if you restart the step from 0.
"""
# async logging leads to deadlocks in sebulba
mode = "async" if architecture_name == "anakin" else "sync"
tracking_uri = os.environ.get("MLFLOW_TRACKING_URI")
if not tracking_uri:
raise ValueError("MLFLOW_TRACKING_URI environment variable is required")
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(experiment_name)

if run_id is not None:
self.logger = neptune.init_run(with_id=run_id, project=project, mode=mode)
self.logger = mlflow.start_run(run_id=run_id)
else:
self.logger = neptune.init_run(project=project, tags=list(tag), mode=mode)
self.logger["sys/group_tags"].add(list(group_tag))
self.logger = mlflow.start_run(tags=tags)

self.detailed_logging = detailed_logging
self.upload_json_data = upload_json_data

# Store json path for uploading json data to Neptune.
# Store json path for uploading json data to MLflow.
json_exp_path = get_logger_path(system_name, "json")
self.json_file_path = Path(base_exp_path, json_exp_path, unique_token, "metrics.json")
self.unique_token = unique_token
Expand All @@ -265,15 +263,110 @@ def log_stat(self, key: str, value: float, step: int, eval_step: int, event: Log
return

value = value.item() if isinstance(value, (jax.Array, np.ndarray)) else value
self.logger[f"{event.value}/{key}"].log(value, step=step)
mlflow.log_metric(f"{event.value}/{key}", value, step=step)

def log_dict(self, data: Metrics, step: int, eval_step: int, event: LogEvent) -> None:
    """Send a dictionary of metrics to MLflow as a single batched call.

    The (possibly nested) `data` is flattened with "/" as the key separator and
    every retained entry is namespaced under `event.value`.

    Args:
        data: Metrics to log.
        step: Step the metrics are recorded against.
        eval_step: Not used by this logger.
        event: Log event whose `value` prefixes every metric name.
    """
    payload: dict[str, float] = {}
    for name, metric in flatten_dict(data, sep="/").items():
        # Unless detailed logging is on, keep only top-level keys and "/mean" aggregates.
        if self.detailed_logging or "/" not in name or name.endswith("/mean"):
            # Unwrap jax/numpy arrays into plain Python scalars.
            if isinstance(metric, (jax.Array, np.ndarray)):
                metric = metric.item()
            payload[f"{event.value}/{name}"] = metric

    # Nothing survived the filter: skip the call entirely.
    if not payload:
        return
    mlflow.log_metrics(payload, step=step)

def log_config(self, config: Dict) -> None:
self.logger["config"] = stringify_unsupported(config)
flat = flatten_dict(config, sep="/")
mlflow.log_params({k: str(v) for k, v in flat.items()})

def stop(self) -> None:
if self.upload_json_data:
self._zip_and_upload_json()
self.logger.stop()
mlflow.end_run()

def _zip_and_upload_json(self) -> None:
    """Zip the metrics JSON file and log the archive as an MLflow artifact."""
    source = self.json_file_path
    # The archive sits next to the JSON file: same path, '.zip' suffix instead of '.json'.
    archive_path = source.with_suffix(".zip").as_posix()

    with zipfile.ZipFile(archive_path, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
        # Store the JSON under its bare file name rather than its full path.
        archive.write(source, arcname=source.name)

    mlflow.log_artifact(archive_path, artifact_path=f"metrics/metrics_{self.unique_token}")


class WandBLogger(BaseLogger):
def __init__(
    self,
    base_exp_path: PathLike,
    unique_token: str,
    system_name: str,
    project: str,
    tags: list[str],
    detailed_logging: bool,
    upload_json_data: bool,
    run_id: str | None = None,
    entity: str | None = None,
) -> None:
    """
    Set up a WandB run for experiment tracking.

    Args:
        base_exp_path: Base path where all logs are stored.
        unique_token: Unique identifier string for this run.
        system_name: Name of the system/algorithm being logged.
        project: WandB project name.
        tags: List of tags for the WandB run (only applied when starting a new run).
        detailed_logging: Whether to log detailed metrics (incl. std/min/max).
        upload_json_data: Whether to upload JSON data to WandB.
        run_id: ID of the run to resume - None starts a fresh run.
            Note this will overwrite the run unless you set the timestep correctly.
        entity: WandB entity (user or team). None uses the default from the env.
    """
    init_kwargs: dict = {"project": project, "entity": entity}
    if run_id is None:
        # Fresh run: attach the configured tags.
        init_kwargs["tags"] = list(tags)
    else:
        # Resume the given run; tags are not passed in this case.
        init_kwargs.update(id=run_id, resume="allow")
    self.logger = wandb.init(**init_kwargs)

    self.detailed_logging = detailed_logging
    self.upload_json_data = upload_json_data

    # Remember where the JSON metrics file is expected so it can be uploaded to WandB later.
    json_subdir = get_logger_path(system_name, "json")
    self.json_file_path = Path(base_exp_path, json_subdir, unique_token, "metrics.json")
    self.unique_token = unique_token

def log_stat(self, key: str, value: float, step: int, eval_step: int, event: LogEvent) -> None:
    """Log a single scalar to WandB under `<event.value>/<key>`.

    Args:
        key: Metric name; may contain "/" for aggregated statistics.
        value: Metric value; jax/numpy arrays are unwrapped with `.item()`.
        step: Step the value is recorded against.
        eval_step: Not used by this logger.
        event: Log event whose `value` prefixes the metric name.
    """
    # Without detailed logging only top-level keys and "/mean" aggregates are kept.
    keep = self.detailed_logging or "/" not in key or key.endswith("/mean")
    if keep:
        scalar = value.item() if isinstance(value, (jax.Array, np.ndarray)) else value
        wandb.log({f"{event.value}/{key}": scalar}, step=step)

def log_dict(self, data: Metrics, step: int, eval_step: int, event: LogEvent) -> None:
    """Send a dictionary of metrics to WandB in a single `wandb.log` call.

    The (possibly nested) `data` is flattened with "/" as the key separator and
    every retained entry is namespaced under `event.value`.

    Args:
        data: Metrics to log.
        step: Step the metrics are recorded against.
        eval_step: Not used by this logger.
        event: Log event whose `value` prefixes every metric name.
    """
    payload: dict[str, float] = {}
    for name, metric in flatten_dict(data, sep="/").items():
        # Unless detailed logging is on, keep only top-level keys and "/mean" aggregates.
        if self.detailed_logging or "/" not in name or name.endswith("/mean"):
            # Unwrap jax/numpy arrays into plain Python scalars.
            if isinstance(metric, (jax.Array, np.ndarray)):
                metric = metric.item()
            payload[f"{event.value}/{name}"] = metric

    # Nothing survived the filter: skip the call entirely.
    if not payload:
        return
    wandb.log(payload, step=step)

def log_config(self, config: Dict) -> None:
    """Record the experiment configuration on the active WandB run.

    Args:
        config: Configuration dictionary to store in the run's config.
    """
    # This logger can resume an existing run (`resume="allow"` in `__init__`).
    # `allow_val_change=True` lets keys that are already set on the run be updated
    # with new values instead of the update being rejected.
    wandb.config.update(config, allow_val_change=True)

def stop(self) -> None:
    """Finish the WandB run, uploading the zipped JSON metrics first when enabled."""
    upload_requested = self.upload_json_data
    if upload_requested:
        # Upload before finishing the run.
        self._zip_and_upload_json()
    wandb.finish()

def _zip_and_upload_json(self) -> None:
# Create the zip file path by replacing '.json' with '.zip'
Expand All @@ -283,7 +376,7 @@ def _zip_and_upload_json(self) -> None:
with zipfile.ZipFile(zip_file_path, "w", zipfile.ZIP_DEFLATED) as zipf:
zipf.write(self.json_file_path, arcname=self.json_file_path.name)

self.logger[f"metrics/metrics_{self.unique_token}"].upload(zip_file_path)
wandb.save(zip_file_path, policy="now")


class TensorboardLogger(BaseLogger):
Expand Down Expand Up @@ -437,19 +530,17 @@ def _make_multi_logger(cfg: DictConfig) -> MultiLogger:
"""Instantiate only enabled loggers and remove the 'enabled' flag."""
unique_token = datetime.now().strftime("%Y%m%d%H%M%S")

if (
cfg.logger.loggers.neptune.enabled
and cfg.logger.loggers.json.enabled
and cfg.logger.loggers.neptune.upload_json_data
and cfg.logger.loggers.json.path
):
raise ValueError(
"Cannot upload json data to Neptune when `json_path` is set in the base logger config. "
"This is because each subsequent run will create a larger json file which will use "
"unnecessary storage. Either set `upload_json_data: false` if you don't want to "
"upload your json data but store a large file locally or set `json_path: ~` in "
"the base logger config."
)
if cfg.logger.loggers.json.enabled and cfg.logger.loggers.json.path:
for name in ("mlflow", "wandb"):
remote_cfg = cfg.logger.loggers[name]
if remote_cfg.enabled and remote_cfg.upload_json_data:
raise ValueError(
f"Cannot upload json data to {name} when `json.path` is set in the base "
"logger config. This is because each subsequent run will create a larger "
"json file which will use unnecessary storage. Either set "
"`upload_json_data: false` if you don't want to upload your json data but "
"store a large file locally or set `json.path: ~` in the base logger config."
)
loggers: List[BaseLogger] = []
for _logger_config in cfg.logger.loggers.values():
logger_config = dict(_logger_config) # Create a copy to avoid modifying the original
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ dependencies = [
"jumanji>= 1.1.1",
"lbforaging",
"matrax>= 0.0.5",
"mlflow",
"mujoco==3.1.3",
"mujoco-mjx==3.1.3",
"neptune",
"numpy==1.26.4",
"omegaconf",
"optax",
Expand All @@ -66,6 +66,7 @@ dependencies = [
"tensorboard_logger",
"tensorflow_probability",
"type_enforced", # needed because gigastep is missing this dependency
"wandb",
]

[project.optional-dependencies]
Expand Down
7 changes: 4 additions & 3 deletions test/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ def _run_system(system_name: str, cfg: DictConfig) -> float:
"""Runs a system."""
OmegaConf.set_struct(cfg, False)
# we never want to log these tests anywhere
cfg.logger.use_neptune = False
cfg.logger.use_tb = False
cfg.logger.use_json = False
cfg.logger.loggers.mlflow.enabled = False
cfg.logger.loggers.wandb.enabled = False
cfg.logger.loggers.tensorboard.enabled = False
cfg.logger.loggers.json.enabled = False

system = importlib.import_module(f"mava.systems.{system_name}")
eval_perf = system.run_experiment(cfg)
Expand Down
Loading
Loading