34 commits
a4ca6a2
Add JAX release support v0.10.0
erman-gurses Jun 23, 2026
dd528cb
Fix gfx_arch: all
erman-gurses Jun 23, 2026
baef7fe
Remove tarball
erman-gurses Jun 23, 2026
6fc71c7
Test
erman-gurses Jun 23, 2026
74ab4d7
Test
erman-gurses Jun 23, 2026
19eda7c
Add back tarball for 0.9.0
erman-gurses Jun 23, 2026
c170950
Test
erman-gurses Jun 23, 2026
e6f750d
Test
erman-gurses Jun 23, 2026
b9a4640
Test
erman-gurses Jun 24, 2026
700e043
Test
erman-gurses Jun 24, 2026
ec7f744
Test
erman-gurses Jun 24, 2026
6c8f879
Test
erman-gurses Jun 24, 2026
785787d
Test
erman-gurses Jun 24, 2026
1c8dc98
Test
erman-gurses Jun 24, 2026
1f13220
Use requirements-jax.txt
erman-gurses Jun 24, 2026
41c015b
Test
erman-gurses Jun 24, 2026
099878f
Remove condition for determine_version.py
erman-gurses Jun 24, 2026
88fc90d
Add options for build_mode
erman-gurses Jun 24, 2026
22c916c
Add default for jax_repository
erman-gurses Jun 24, 2026
8b6f2f5
Update gfx_arch description
erman-gurses Jun 24, 2026
9d58a73
Move repository and ref inputs to the bottom
erman-gurses Jun 24, 2026
084898b
Move jax_repository, build_mode, and gfx_arch inputs up
erman-gurses Jun 24, 2026
382f39e
Solve --find_links inconsistency use index_url
erman-gurses Jun 24, 2026
dcf62ab
Add package_index_url output for better ROCm package indexing
erman-gurses Jun 25, 2026
50c652d
Add unit test
erman-gurses Jun 25, 2026
9b1e7ed
Increase shared memory size
erman-gurses Jun 25, 2026
ff1f03f
Reduce memory pressure
erman-gurses Jun 25, 2026
f7448a4
Test memory
erman-gurses Jun 25, 2026
363b89c
Test
erman-gurses Jun 25, 2026
0bdd100
Test
erman-gurses Jun 25, 2026
3d46f1b
Get original testing script
erman-gurses Jun 25, 2026
c15d36d
Update README.md
erman-gurses Jun 25, 2026
35a873a
Merge branch 'main' into users/erman-gurses/add-jax10x-support
erman-gurses Jun 25, 2026
c6bacdd
Test GPU memory error
erman-gurses Jun 25, 2026
127 changes: 106 additions & 21 deletions .github/workflows/multi_arch_build_linux_jax_wheels.yml
@@ -29,10 +29,6 @@ on:
description: "JAX / rocm-jax release ref."
type: string
required: true
build_jaxlib:
description: "Whether this matrix cell should build jaxlib from JAX source."
type: boolean
required: true
rocm_version:
description: "ROCm package version to install and build against."
type: string
@@ -53,6 +49,19 @@ on:
description: "Branch, tag, or SHA to checkout. Defaults to the triggering ref."
type: string
default: ""
jax_repository:
description: "Repository containing the JAX release branch."
type: string
required: true
build_mode:
description: "Build mode: native or manylinux."
type: string
required: true
gfx_arch:
description: "GFX architecture used for the manylinux image build."
type: string
required: false
default: ""
outputs:
package_index_url:
description: "Package index URL for the multi-arch release bucket."
@@ -79,10 +88,6 @@ on:
description: "JAX / rocm-jax release ref."
type: string
required: true
build_jaxlib:
description: "Whether this matrix cell should build jaxlib from JAX source."
type: boolean
required: true
rocm_version:
description: "ROCm package version to install and build against."
type: string
@@ -103,6 +108,19 @@ on:
description: "Branch, tag, or SHA to checkout. Defaults to the triggering ref."
type: string
default: ""
Comment on lines 127 to 134

These repository and ref inputs aren't commonly provided when triggered by developers, so I'd prefer to keep them as the last inputs in the list. Looks like a few other workflows have similarly started adding inputs beyond them though.

I would move jax_repository, build_mode, and gfx_arch near the top of the inputs list since they will be set frequently.

jax_repository:
description: "Repository containing the JAX release branch."
type: string
required: true

Can you add a default here, or at least mention which repositories are expected in the description? What syntax is this expecting, ROCm/jax or https://github.com/ROCm/jax?
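
On the syntax question: the `repository` input of `actions/checkout` takes the `owner/name` form (for example `ROCm/jax`), not a full URL. A minimal sketch of a guard that could run before the checkout step, assuming a hypothetical helper named `is_repo_slug` that is not part of this PR:

```shell
#!/usr/bin/env bash

# Hypothetical helper (not in the PR): succeeds only for an "owner/name" slug.
is_repo_slug() {
  printf '%s\n' "$1" | grep -Eq '^[A-Za-z0-9_.-]+/[A-Za-z0-9_.-]+$'
}

# Compare the two spellings mentioned in the review comment.
for candidate in "ROCm/jax" "https://github.com/ROCm/jax"; do
  if is_repo_slug "${candidate}"; then
    echo "ok:       ${candidate}"
  else
    echo "rejected: ${candidate}"
  fi
done
```

The slug form passes and the URL form is rejected, which is roughly what the input description could spell out.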

build_mode:
description: "Build mode: native or manylinux."
type: string
required: true

nit: If there are only two expected string values, for workflow_dispatch you could use a "choice" input type to make this easier to use correctly.

https://docs.github.com/en/actions/reference/workflows-and-actions/workflow-syntax#onworkflow_dispatchinputs
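
One caveat: `choice` is only available for `workflow_dispatch` inputs, while `workflow_call` inputs are limited to `boolean`, `number`, and `string`, so the reusable-workflow path might still want a runtime check. A minimal sketch, assuming a hypothetical `validate_build_mode` helper that is not part of this PR:

```shell
#!/usr/bin/env bash

# Hypothetical guard (not in the PR): accept only the two documented modes.
validate_build_mode() {
  case "$1" in
    native|manylinux)
      return 0
      ;;
    *)
      echo "Unsupported build_mode '$1' (expected 'native' or 'manylinux')" >&2
      return 1
      ;;
  esac
}

validate_build_mode "manylinux" && echo "manylinux accepted"
validate_build_mode "manylinx" || echo "typo rejected"
```

Failing early like this would keep a misspelled mode from silently skipping both the native and the manylinux build steps, since each is gated on an exact string comparison.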

gfx_arch:
description: "GFX architecture used for the manylinux image build."
type: string
required: false
default: ""

Add an example to the description? We have a few syntax styles (gfx94X, gfx942, gfx942-dcgpu, etc.) and that helps a lot when working across multiple workflows.


run-name: Build Multi-Arch Linux JAX Wheels (${{ inputs.release_type }}, ${{ inputs.rocm_version }}, Multiarch)

@@ -121,7 +139,7 @@ jobs:
jax_plugin_version: ${{ steps.write_jax_versions.outputs.jax_plugin_version }}
jax_pjrt_version: ${{ steps.write_jax_versions.outputs.jax_pjrt_version }}
env:
PACKAGE_DIST_DIR: ${{ github.workspace }}/rocm-jax/jax_rocm_plugin/wheelhouse
MANYLINUX_IMAGE_TAG: therock-jax-manylinux:${{ inputs.jax_ref }}-${{ inputs.python_version }}

steps:
- name: Checkout TheRock
@@ -131,18 +149,19 @@ jobs:
ref: ${{ inputs.ref || github.ref_name }}

- name: Checkout rocm-jax
if: ${{ inputs.build_mode == 'native' }}
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with:
path: rocm-jax
repository: ROCm/rocm-jax
ref: ${{ inputs.jax_ref }}

- name: Checkout JAX
if: ${{ inputs.build_jaxlib }}
if: ${{ inputs.build_mode == 'manylinux' }}
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with:
path: jax
repository: ROCm/jax
path: jax-source
repository: ${{ inputs.jax_repository }}
ref: ${{ inputs.jax_ref }}

- name: Configure Git Identity
@@ -159,17 +178,16 @@ jobs:
run: |
python -m pip install boto3

- name: Determine build arguments
env:
BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
- name: Set package dist dir
run: |
if [ "$BUILD_JAXLIB" = "true" ]; then
echo "SOURCE_ARG=--jax-source-dir=$GITHUB_WORKSPACE/jax" >> "$GITHUB_ENV"
if [ "${{ inputs.build_mode }}" = "manylinux" ]; then
echo "PACKAGE_DIST_DIR=${GITHUB_WORKSPACE}/jax-source/dist" >> "$GITHUB_ENV"
else
echo "SOURCE_ARG=" >> "$GITHUB_ENV"
echo "PACKAGE_DIST_DIR=${GITHUB_WORKSPACE}/rocm-jax/jax_rocm_plugin/wheelhouse" >> "$GITHUB_ENV"
fi

- name: Build JAX Wheels
- name: Build JAX Wheels (native, v0.9.1)
if: ${{ inputs.build_mode == 'native' }}
working-directory: rocm-jax
env:
ROCM_VERSION: ${{ inputs.rocm_version }}
@@ -181,9 +199,77 @@ jobs:
--python-versions="${PYTHON_VERSION}" \
--rocm-version="${ROCM_VERSION}" \
--therock-path="${TAR_URL}" \
${SOURCE_ARG} \
dist_wheels

- name: Build manylinux image
if: ${{ inputs.build_mode == 'manylinux' }}
working-directory: jax-source
env:
BRANCH: ${{ inputs.jax_ref }}
THEROCK_INDEX_URL: https://rocm.nightlies.amd.com/whl-multi-arch/
THEROCK_VERSION: ""
GFX_ARCH: ${{ inputs.gfx_arch }}
run: |
curl -fsSL \
-o Dockerfile.jax-manylinux_2_28-therock \
"https://raw.githubusercontent.com/ROCm/rocm-jax/${BRANCH}/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock"

docker build \
-t "${MANYLINUX_IMAGE_TAG}" \
--file=Dockerfile.jax-manylinux_2_28-therock \
--build-arg=THEROCK_INDEX_URL="${THEROCK_INDEX_URL}" \
--build-arg=THEROCK_VERSION="${THEROCK_VERSION}" \
--build-arg=GFX_ARCH="${GFX_ARCH}" \
.

- name: Compute wheel version suffix
if: ${{ inputs.build_mode == 'manylinux' }}
run: |
ROCM_VERSION_BASE="${ROCM_VERSION%%+*}"
ROCM_VERSION_BUILD="${ROCM_VERSION#*+}"

if [ "${RELEASE_TYPE}" = "nightly" ]; then
echo "ML_WHEEL_VERSION_SUFFIX=+rocm${ROCM_VERSION_BASE}a$(date -u +%Y%m%d)" >> "$GITHUB_ENV"
else
echo "ML_WHEEL_VERSION_SUFFIX=+rocm${ROCM_VERSION_BASE}.${ROCM_VERSION_BUILD}" >> "$GITHUB_ENV"
fi
env:
ROCM_VERSION: ${{ inputs.rocm_version }}
RELEASE_TYPE: ${{ inputs.release_type }}
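
To see what the two parameter expansions in this step produce, here is a small standalone sketch. The function name `wheel_version_suffix`, the sample version strings, and the `dev` release type are placeholders chosen for illustration; the date is passed as an argument instead of calling `date -u` so the output is reproducible.

```shell
#!/usr/bin/env bash

# Hypothetical helper mirroring the step's logic; not part of the workflow.
# Usage: wheel_version_suffix <rocm_version> <release_type> <yyyymmdd>
wheel_version_suffix() {
  base="${1%%+*}"    # everything before the first '+'
  build="${1#*+}"    # everything after the first '+' (whole string if no '+')
  if [ "$2" = "nightly" ]; then
    printf '+rocm%sa%s\n' "${base}" "$3"
  else
    printf '+rocm%s.%s\n' "${base}" "${build}"
  fi
}

# Placeholder inputs; real ROCm version strings may look different.
wheel_version_suffix "7.12.0+abc123" nightly 20260625
wheel_version_suffix "7.12.0+abc123" dev 20260625
wheel_version_suffix "7.12.0" dev 20260625
```

This prints `+rocm7.12.0a20260625`, `+rocm7.12.0.abc123`, and `+rocm7.12.0.7.12.0`. The last case is worth noting: when the version has no `+`, neither expansion matches, so the base version is repeated in the non-nightly branch.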

- name: Build JAX Wheels in manylinux container
if: ${{ inputs.build_mode == 'manylinux' }}
working-directory: jax-source
env:
ROCM_VERSION: ${{ inputs.rocm_version }}
PYTHON_VERSION: ${{ inputs.python_version }}
TAR_URL: ${{ inputs.tar_url }}
run: |
docker run --rm \
Comment on lines +238 to +245

Can you also update the docs at https://github.com/ROCm/TheRock/tree/main/external-builds/jax#build-instructions ? The general progression for development work and feature enablement should be local builds --> dev builds --> release builds. This jumps straight to release builds in GitHub Actions workflows, which are the hardest to debug.

--user root \
--env ROCM_VERSION="${ROCM_VERSION}" \
--env PYTHON_VERSION="${PYTHON_VERSION}" \
--env TAR_URL="${TAR_URL}" \
--env ML_WHEEL_VERSION_SUFFIX="${ML_WHEEL_VERSION_SUFFIX}" \
--env PACKAGE_DIST_DIR="/workspace/jax-source/dist" \
--volume "${GITHUB_WORKSPACE}:/workspace" \
--workdir /workspace/jax-source \
"${MANYLINUX_IMAGE_TAG}" \
bash -lc '
python build/build.py build \
--wheels=jax-rocm-plugin,jax-rocm-pjrt \
--python_version="${PYTHON_VERSION}" \
--bazel_startup_options=--bazelrc=build/rocm/rocm.bazelrc \
--bazel_options=--config=rocm_release_wheel \
--bazel_options=--repo_env=ROCM_PATH=$(rocm-sdk path --root) \
--bazel_options=--repo_env=ML_WHEEL_TYPE=release \
--bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="${ML_WHEEL_VERSION_SUFFIX}" \
--bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD) \
--verbose \
--detailed_timestamped_log \
--output_path=$(pwd)/dist
'

- name: Extract JAX versions from built wheels
id: write_jax_versions
run: |
@@ -203,7 +289,6 @@ jobs:
--release-type="${{ inputs.release_type }}"



generate_target_to_run:
name: Generate target_to_run
if: ${{ inputs.test_amdgpu_family != '' }}
16 changes: 12 additions & 4 deletions .github/workflows/multi_arch_release_linux_jax_wheels.yml
@@ -3,7 +3,6 @@

# Multi-arch release entry point for Linux JAX wheels.
#
# Currently this workflow is uses only rocm-jaxlib-v0.9.1
# This workflow owns the release matrix and delegates each matrix cell to the
# reusable single-build workflow.

@@ -69,7 +68,6 @@ run-name: Release Linux JAX Wheels (${{ inputs.release_type }}, ${{ inputs.rocm_

permissions:
contents: read
#TODO: Add rocm-jaxlib-v0.10.0 once it is released and tested.
jobs:
build_jax_wheels:
name: Build | py ${{ matrix.python_version }} | jax ${{ matrix.jax_ref }}
@@ -79,9 +77,17 @@ jobs:
python_version: ["3.11", "3.12", "3.13", "3.14"]
jax_ref:
- "rocm-jaxlib-v0.9.1"
- "rocm-jaxlib-v0.10.0"
include:
- jax_ref: "rocm-jaxlib-v0.9.1"
build_jaxlib: false
jax_repository: "ROCm/rocm-jax"
build_mode: "native"
gfx_arch: ""

- jax_ref: "rocm-jaxlib-v0.10.0"
jax_repository: "ROCm/jax"
build_mode: "manylinux"
gfx_arch: "all"

permissions:
id-token: write
@@ -91,7 +97,9 @@ jobs:
test_amdgpu_family: ${{ inputs.test_amdgpu_family }}
python_version: ${{ matrix.python_version }}
jax_ref: ${{ matrix.jax_ref }}
build_jaxlib: ${{ matrix.build_jaxlib }}
jax_repository: ${{ matrix.jax_repository }}
build_mode: ${{ matrix.build_mode }}
gfx_arch: ${{ matrix.gfx_arch }}
rocm_version: ${{ inputs.rocm_version }}
tar_url: ${{ inputs.tar_url }}
release_type: ${{ inputs.release_type }}
1 change: 1 addition & 0 deletions .github/workflows/test_linux_jax_wheels_partial.yml
@@ -249,6 +249,7 @@ jobs:
JAX_ROCM_PLUGIN_INTERNAL_BITCODE_PATH: /opt/rocm/lib/llvm/amdgcn/bitcode
JAX_ROCM_PLUGIN_INTERNAL_LLD_PATH: /opt/rocm/lib/llvm/bin
run: |
python -c "import jax; print(jax.local_devices())"
pytest jax/tests/multi_device_test.py -q --log-cli-level=INFO
pytest jax/tests/core_test.py -q --log-cli-level=INFO
pytest jax/tests/util_test.py -q --log-cli-level=INFO