Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 44 additions & 13 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,23 +1,54 @@
# Stage 1: Build environment
FROM python:3.12-slim AS core
FROM python:3.11-slim AS core

# Add git
RUN apt-get update && apt-get install -y git build-essential pkg-config libhdf5-dev
# Install system dependencies (git, build-essential, etc.)
# We add --no-install-recommends and clean up apt lists to keep the image slim.
RUN apt-get update && apt-get install -y \
git \
build-essential \
pkg-config \
libhdf5-dev \
--no-install-recommends && \
rm -rf /var/lib/apt/lists/*

# Add uv and use the system python (no need to make venv)
USER root
COPY --from=ghcr.io/astral-sh/uv:0.5.4 /uv /bin/uv
ENV UV_SYSTEM_PYTHON=1
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

# Use the system-wide Python (no venv needed in container)
ENV UV_SYSTEM_PYTHON=1
WORKDIR /home/app/mava

COPY . .
# Build-time argument to control GPU installation.
# Build with: docker build --build-arg USE_CUDA=true -t my-image .
ARG USE_CUDA=false

RUN uv pip install -e .
# --- Dependency Installation Layer ---
# Copy only the file needed for dependency resolution
COPY pyproject.toml .
COPY uv.lock .

ARG USE_CUDA=false
RUN if [ "$USE_CUDA" = true ] ; \
then uv pip install jax[cuda12]==0.4.30 ; \
# Install all dependencies *except* the local project
# We use a shell variable (JAX_EXTRA) to conditionally add the 'cuda12' extra
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "$USE_CUDA" = true ] ; then \
uv sync --locked --no-install-project --extra cuda12 ; \
else \
uv sync --locked --no-install-project ; \
fi

# --- Application Code Layer ---
# This layer is cached and only re-runs if Mava code changes.

# Copy all the application source code
COPY . .

# Install the local project itself
# We pass the JAX_EXTRA variable again to ensure the
# full dependency set (including the project) is synced correctly.
RUN --mount=type=cache,target=/root/.cache/uv \
if [ "$USE_CUDA" = true ] ; then \
uv sync --locked --extra cuda12 ; \
else \
uv sync --locked ; \
fi

# Expose Tensorboard port
EXPOSE 6006
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ build:
DOCKER_BUILDKIT=1 docker build --build-arg USE_CUDA=$(USE_CUDA) --tag $(IMAGE) .

run:
$(DOCKER_RUN) python $(example)
$(DOCKER_RUN) uv run $(example)

bash:
$(DOCKER_RUN) bash
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ dependencies = [
"numpy==1.26.4",
"omegaconf",
"optax",
"orbax-checkpoint==0.11.20",
"protobuf~=3.20",
"rware",
"scipy==1.12.0",
Expand Down
20 changes: 16 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.