Skip to content
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
a4ca6a2
Add JAX release support v0.10.0
erman-gurses Jun 23, 2026
dd528cb
Fix gfx_arch: all
erman-gurses Jun 23, 2026
baef7fe
Remove tarball
erman-gurses Jun 23, 2026
6fc71c7
Test
erman-gurses Jun 23, 2026
74ab4d7
Test
erman-gurses Jun 23, 2026
19eda7c
Add back tarball for 0.9.0
erman-gurses Jun 23, 2026
c170950
Test
erman-gurses Jun 23, 2026
e6f750d
Test
erman-gurses Jun 23, 2026
b9a4640
Test
erman-gurses Jun 24, 2026
700e043
Test
erman-gurses Jun 24, 2026
ec7f744
Test
erman-gurses Jun 24, 2026
6c8f879
Test
erman-gurses Jun 24, 2026
785787d
Test
erman-gurses Jun 24, 2026
1c8dc98
Test
erman-gurses Jun 24, 2026
1f13220
Use requirements-jax.txt
erman-gurses Jun 24, 2026
41c015b
Test
erman-gurses Jun 24, 2026
099878f
Remove condition for determine_version.py
erman-gurses Jun 24, 2026
88fc90d
Add options for build_mode
erman-gurses Jun 24, 2026
22c916c
Add default for jax_repository
erman-gurses Jun 24, 2026
8b6f2f5
Update gfx_arch description
erman-gurses Jun 24, 2026
9d58a73
Move repository and ref inputs to the bottom
erman-gurses Jun 24, 2026
084898b
Move jax_repositor,y build_mode, and gfx_arch inputs up
erman-gurses Jun 24, 2026
382f39e
Solve --find_links inconsistency use index_url
erman-gurses Jun 24, 2026
dcf62ab
Add package_index_url output for better ROCm package indexing
erman-gurses Jun 25, 2026
50c652d
Add unit test
erman-gurses Jun 25, 2026
9b1e7ed
Increase shared memory size
erman-gurses Jun 25, 2026
ff1f03f
Reduce memory pressure
erman-gurses Jun 25, 2026
f7448a4
Test memory
erman-gurses Jun 25, 2026
363b89c
Test
erman-gurses Jun 25, 2026
0bdd100
Test
erman-gurses Jun 25, 2026
3d46f1b
Get original testing script
erman-gurses Jun 25, 2026
c15d36d
Update README.md
erman-gurses Jun 25, 2026
35a873a
Merge branch 'main' into users/erman-gurses/add-jax10x-support
erman-gurses Jun 25, 2026
c6bacdd
Test GPU memory error
erman-gurses Jun 25, 2026
d1d2b54
Merge branch 'main' into users/erman-gurses/add-jax10x-support
erman-gurses Jun 25, 2026
ccc240e
Fix JAX ROCm test GPU isolation in CI
erman-gurses Jun 25, 2026
604ad16
Update GPU test setup
erman-gurses Jun 25, 2026
ec428f0
Test JAX
erman-gurses Jun 25, 2026
db44800
Remove old command
erman-gurses Jun 25, 2026
afe0d12
Update README
erman-gurses Jun 25, 2026
9c9e172
Update README
erman-gurses Jun 25, 2026
e130075
Fix format for md
erman-gurses Jun 25, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 106 additions & 24 deletions .github/workflows/multi_arch_build_linux_jax_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ on:
description: "JAX / rocm-jax release ref."
type: string
required: true
build_jaxlib:
description: "Whether this matrix cell should build jaxlib from JAX source."
type: boolean
required: true
rocm_version:
description: "ROCm package version to install and build against."
type: string
required: true
rocm_package_find_links_url:
description: "ROCm package index / find-links URL for the manylinux build."
type: string
required: true
tar_url:
description: "URL to the TheRock tarball used for the build."
type: string
Expand All @@ -53,6 +53,19 @@ on:
description: "Branch, tag, or SHA to checkout. Defaults to the triggering ref."
type: string
default: ""
jax_repository:
description: "Repository containing the JAX release branch."
type: string
required: true
build_mode:
description: "Build mode: native or manylinux."
type: string
required: true
gfx_arch:
description: "GFX architecture used for the manylinux image build."
type: string
required: false
default: ""
outputs:
package_index_url:
description: "Package index URL for the multi-arch release bucket."
Expand All @@ -79,14 +92,14 @@ on:
description: "JAX / rocm-jax release ref."
type: string
required: true
build_jaxlib:
description: "Whether this matrix cell should build jaxlib from JAX source."
type: boolean
required: true
rocm_version:
description: "ROCm package version to install and build against."
type: string
required: true
rocm_package_find_links_url:
description: "ROCm package index / find-links URL for the manylinux build."
type: string
required: true
tar_url:
description: "URL to the TheRock tarball used for the build."
type: string
Expand All @@ -103,6 +116,19 @@ on:
description: "Branch, tag, or SHA to checkout. Defaults to the triggering ref."
type: string
default: ""
Comment thread
erman-gurses marked this conversation as resolved.
jax_repository:
description: "Repository containing the JAX release branch."
type: string
required: true
Comment thread
erman-gurses marked this conversation as resolved.
Outdated
build_mode:
description: "Build mode: native or manylinux."
type: string
required: true
Comment thread
erman-gurses marked this conversation as resolved.
Outdated
gfx_arch:
description: "GFX architecture used for the manylinux image build."
type: string
required: false
default: ""
Comment thread
erman-gurses marked this conversation as resolved.
Outdated

run-name: Build Multi-Arch Linux JAX Wheels (${{ inputs.release_type }}, ${{ inputs.rocm_version }}, Multiarch)

Expand All @@ -121,7 +147,7 @@ jobs:
jax_plugin_version: ${{ steps.write_jax_versions.outputs.jax_plugin_version }}
jax_pjrt_version: ${{ steps.write_jax_versions.outputs.jax_pjrt_version }}
env:
PACKAGE_DIST_DIR: ${{ github.workspace }}/rocm-jax/jax_rocm_plugin/wheelhouse
MANYLINUX_IMAGE_TAG: therock-jax-manylinux:${{ inputs.jax_ref }}-${{ inputs.python_version }}

steps:
- name: Checkout TheRock
Expand All @@ -138,13 +164,15 @@ jobs:
ref: ${{ inputs.jax_ref }}

- name: Checkout JAX
if: ${{ inputs.build_jaxlib }}
if: ${{ inputs.build_mode == 'manylinux' }}
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
with:
path: jax
repository: ROCm/jax
path: jax-source
repository: ${{ inputs.jax_repository }}
ref: ${{ inputs.jax_ref }}



- name: Configure Git Identity
run: |
git config --global user.name "therockbot"
Expand All @@ -155,21 +183,21 @@ jobs:
with:
python-version: ${{ inputs.python_version }}

- name: Install boto3
- name: Install python deps for CI
if: ${{ inputs.build_mode == 'manylinux' }}
run: |
python -m pip install boto3
pip install -r external-builds/jax/requirements-jax.txt

- name: Determine build arguments
env:
BUILD_JAXLIB: ${{ inputs.build_jaxlib }}
- name: Set package dist dir
run: |
if [ "$BUILD_JAXLIB" = "true" ]; then
echo "SOURCE_ARG=--jax-source-dir=$GITHUB_WORKSPACE/jax" >> "$GITHUB_ENV"
if [ "${{ inputs.build_mode }}" = "manylinux" ]; then
echo "PACKAGE_DIST_DIR=${GITHUB_WORKSPACE}/jax-source/dist" >> "$GITHUB_ENV"
else
echo "SOURCE_ARG=" >> "$GITHUB_ENV"
echo "PACKAGE_DIST_DIR=${GITHUB_WORKSPACE}/rocm-jax/jax_rocm_plugin/wheelhouse" >> "$GITHUB_ENV"
fi

- name: Build JAX Wheels
- name: Build JAX Wheels (native, v0.9.1)
if: ${{ inputs.build_mode == 'native' }}
working-directory: rocm-jax
env:
ROCM_VERSION: ${{ inputs.rocm_version }}
Expand All @@ -181,9 +209,64 @@ jobs:
--python-versions="${PYTHON_VERSION}" \
--rocm-version="${ROCM_VERSION}" \
--therock-path="${TAR_URL}" \
${SOURCE_ARG} \
dist_wheels

- name: Build manylinux image
if: ${{ inputs.build_mode == 'manylinux' }}
working-directory: jax-source
env:
THEROCK_VERSION: ""
GFX_ARCH: ${{ inputs.gfx_arch }}
ROCM_PACKAGE_FIND_LINKS_URL: ${{ inputs.rocm_package_find_links_url }}
run: |
cp ../rocm-jax/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock \
Dockerfile.jax-manylinux_2_28-therock

docker build \
-t "${MANYLINUX_IMAGE_TAG}" \
--file=Dockerfile.jax-manylinux_2_28-therock \
--build-arg=THEROCK_INDEX_URL="${ROCM_PACKAGE_FIND_LINKS_URL}" \
--build-arg=THEROCK_VERSION="${THEROCK_VERSION}" \
--build-arg=GFX_ARCH="${GFX_ARCH}" \
Comment thread
erman-gurses marked this conversation as resolved.
.

- name: Determine wheel version suffix
if: ${{ inputs.build_mode == 'manylinux' }}
run: |
python build_tools/github_actions/determine_version.py \
--rocm-version "${{ inputs.rocm_version }}"

- name: Build JAX Wheels in manylinux container
if: ${{ inputs.build_mode == 'manylinux' }}
working-directory: jax-source
env:
ROCM_VERSION: ${{ inputs.rocm_version }}
PYTHON_VERSION: ${{ inputs.python_version }}
run: |
docker run --rm \
Comment thread
erman-gurses marked this conversation as resolved.
--user root \
--env ROCM_VERSION="${ROCM_VERSION}" \
--env PYTHON_VERSION="${PYTHON_VERSION}" \
--env ML_WHEEL_VERSION_SUFFIX="${version_suffix}" \
--env PACKAGE_DIST_DIR="/workspace/jax-source/dist" \
--volume "${GITHUB_WORKSPACE}:/workspace" \
--workdir /workspace/jax-source \
"${MANYLINUX_IMAGE_TAG}" \
bash -lc '
python build/build.py build \
--wheels=jax-rocm-plugin,jax-rocm-pjrt \
--python_version="${PYTHON_VERSION}" \
--bazel_startup_options=--bazelrc=build/rocm/rocm.bazelrc \
--bazel_options=--config=rocm_release_wheel \
--bazel_options=--repo_env=ROCM_PATH=$(rocm-sdk path --root) \
--bazel_options=--repo_env=ML_WHEEL_TYPE=release \
--bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="${ML_WHEEL_VERSION_SUFFIX}" \
--bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD) \
--verbose \
--detailed_timestamped_log \
--output_path=$(pwd)/dist
'

- name: Extract JAX versions from built wheels
id: write_jax_versions
run: |
Expand All @@ -203,7 +286,6 @@ jobs:
--release-type="${{ inputs.release_type }}"



generate_target_to_run:
name: Generate target_to_run
if: ${{ inputs.test_amdgpu_family != '' }}
Expand Down Expand Up @@ -236,7 +318,7 @@ jobs:
test_amdgpu_family: ${{ inputs.test_amdgpu_family }}
package_index_url: ${{ needs.build_jax_wheels.outputs.package_index_url }}
rocm_version: ${{ inputs.rocm_version }}
tar_url: ${{ inputs.tar_url }}
rocm_package_find_links_url: ${{ inputs.rocm_package_find_links_url }}
release_type: ${{ inputs.release_type }}
python_version: ${{ inputs.python_version }}
repository: ${{ inputs.repository || github.repository }}
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/multi_arch_release_linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ jobs:
}

trigger_release_jax_wheels:
needs: [build_artifacts, build_tarballs]
needs: [build_artifacts, build_tarballs, build_python_packages]
name: Trigger Release JAX Wheels
runs-on: ubuntu-24.04
permissions:
Expand All @@ -260,6 +260,7 @@ jobs:
"test_amdgpu_family": "${{ contains(fromJSON(inputs.build_config).dist_amdgpu_families, 'gfx94X-dcgpu') && 'gfx94X-dcgpu' || '' }}",
"release_type": "${{ inputs.release_type }}",
"rocm_version": "${{ inputs.rocm_package_version }}",
"rocm_package_find_links_url": "${{ needs.build_python_packages.outputs.package_find_links_url }}",
"tar_url": "${{ fromJSON(needs.build_tarballs.outputs.tarball_urls).multiarch }}",
"repository": "${{ inputs.repository || github.repository }}",
"ref": "${{ inputs.ref || '' }}"
Expand Down
25 changes: 21 additions & 4 deletions .github/workflows/multi_arch_release_linux_jax_wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

# Multi-arch release entry point for Linux JAX wheels.
#
# Currently this workflow is uses only rocm-jaxlib-v0.9.1
# This workflow owns the release matrix and delegates each matrix cell to the
# reusable single-build workflow.

Expand All @@ -24,6 +23,10 @@ on:
description: "ROCm package version to build against."
type: string
required: true
rocm_package_find_links_url:
description: "ROCm package index / find-links URL for the JAX manylinux build."
type: string
default: "https://rocm.devreleases.amd.com/whl-multi-arch/"

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This URL is not compatible with --find-links, it is only compatible with --index-url or --extra-index-url. I'd leave this default off, as we have in the pytorch workflows:

rocm_package_index_url:
description: URL for pip --index-url to install ROCm packages (PEP 503 index, e.g. repo.amd.com; use this OR rocm_package_find_links_url)
type: string
default: ""

A --find-links URL would be more like https://therock-dev-artifacts.s3.amazonaws.com/27944292460-linux/python/index.html

tar_url:
description: "URL to the TheRock tarball used for the build."
type: string
Expand Down Expand Up @@ -52,6 +55,10 @@ on:
description: "ROCm package version to build against."
type: string
required: true
rocm_package_find_links_url:
description: "ROCm package index / find-links URL for the JAX manylinux build."
type: string
default: "https://rocm.devreleases.amd.com/whl-multi-arch/"
Comment thread
erman-gurses marked this conversation as resolved.
Outdated
tar_url:
description: "URL to the TheRock tarball used for the build."
type: string
Expand All @@ -69,7 +76,6 @@ run-name: Release Linux JAX Wheels (${{ inputs.release_type }}, ${{ inputs.rocm_

permissions:
contents: read
#TODO: Add rocm-jaxlib-v0.10.0 once it is released and tested.
jobs:
build_jax_wheels:
name: Build | py ${{ matrix.python_version }} | jax ${{ matrix.jax_ref }}
Expand All @@ -79,9 +85,17 @@ jobs:
python_version: ["3.11", "3.12", "3.13", "3.14"]
jax_ref:
- "rocm-jaxlib-v0.9.1"
- "rocm-jaxlib-v0.10.0"
include:
- jax_ref: "rocm-jaxlib-v0.9.1"
build_jaxlib: false
jax_repository: "ROCm/rocm-jax"
build_mode: "native"
gfx_arch: ""

- jax_ref: "rocm-jaxlib-v0.10.0"
jax_repository: "ROCm/jax"
build_mode: "manylinux"
gfx_arch: "device-all"

permissions:
id-token: write
Expand All @@ -91,8 +105,11 @@ jobs:
test_amdgpu_family: ${{ inputs.test_amdgpu_family }}
python_version: ${{ matrix.python_version }}
jax_ref: ${{ matrix.jax_ref }}
build_jaxlib: ${{ matrix.build_jaxlib }}
jax_repository: ${{ matrix.jax_repository }}
build_mode: ${{ matrix.build_mode }}
gfx_arch: ${{ matrix.gfx_arch }}
rocm_version: ${{ inputs.rocm_version }}
rocm_package_find_links_url: ${{ inputs.rocm_package_find_links_url }}
tar_url: ${{ inputs.tar_url }}
release_type: ${{ inputs.release_type }}
repository: ${{ inputs.repository || github.repository }}
Expand Down
38 changes: 19 additions & 19 deletions .github/workflows/test_linux_jax_wheels_partial.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ on:
description: ROCm version
required: true
type: string
tar_url:
description: URL to TheRock tarball
rocm_package_find_links_url:
description: "ROCm package index / find-links URL for installing ROCm packages."
required: true
type: string
release_type:
Expand Down Expand Up @@ -76,10 +76,10 @@ on:
description: ROCm version
type: string
default: "7.14.0.dev0"
tar_url:
description: URL to TheRock tarball
rocm_package_find_links_url:
description: "ROCm package index / find-links URL for installing ROCm packages."
required: true
type: string
default: ""
release_type:
description: 'Release type; developer-triggered jobs should use "dev".'
type: choice
Expand Down Expand Up @@ -149,9 +149,9 @@ jobs:

env:
VIRTUAL_ENV: ${{ github.workspace }}/.venv
TAR_URL: ${{ inputs.tar_url }}
PYTHON_VERSION: ${{ inputs.python_version }}
WHEEL_INDEX_URL: ${{ inputs.package_index_url }}
ROCM_PACKAGE_FIND_LINKS_URL: ${{ inputs.rocm_package_find_links_url }}

steps:
- name: Checkout
Expand Down Expand Up @@ -192,19 +192,6 @@ jobs:
run: |
pip install -r external-builds/jax/requirements-jax.txt

- name: Configure ROCm from TheRock tarball
run: |
DEST="/opt/rocm-${{ inputs.rocm_version }}"
TARBALL="${RUNNER_TEMP}/therock.tar.gz"

rm -rf "${DEST}"
mkdir -p "${DEST}"

wget -O "${TARBALL}" "${TAR_URL}"
tar -xzf "${TARBALL}" -C "${DEST}"

ln -sfn "${DEST}" /opt/rocm

- name: Determine JAX package versions
run: |
if [ -n "${{ inputs.jax_plugin_version }}" ] && [ -n "${{ inputs.jax_version }}" ]; then
Expand All @@ -220,6 +207,18 @@ jobs:
--jax-requirements rocm-jax/build/requirements.txt
fi

- name: Install ROCm packages
run: |
python -m pip install --upgrade pip
python -m pip install --pre \
--index-url "${ROCM_PACKAGE_FIND_LINKS_URL}" \
Comment thread
erman-gurses marked this conversation as resolved.
Outdated
rocm \
rocm-sdk-devel \
Comment thread
erman-gurses marked this conversation as resolved.
Outdated
rocm-sdk-device-gfx942
Comment thread
erman-gurses marked this conversation as resolved.
Outdated
rocm-sdk init
echo "ROCM_ROOT=$(rocm-sdk path --root)" >> "$GITHUB_ENV"
ln -sfn "$(rocm-sdk path --root)" /opt/rocm
Comment thread
erman-gurses marked this conversation as resolved.
Outdated

- name: Install JAX wheels from package index
run: |
echo "Installing from:"
Expand Down Expand Up @@ -249,6 +248,7 @@ jobs:
JAX_ROCM_PLUGIN_INTERNAL_BITCODE_PATH: /opt/rocm/lib/llvm/amdgcn/bitcode
JAX_ROCM_PLUGIN_INTERNAL_LLD_PATH: /opt/rocm/lib/llvm/bin
run: |
python -c "import jax; print(jax.local_devices())"
pytest jax/tests/multi_device_test.py -q --log-cli-level=INFO
pytest jax/tests/core_test.py -q --log-cli-level=INFO
pytest jax/tests/util_test.py -q --log-cli-level=INFO
Expand Down
Loading