-
Notifications
You must be signed in to change notification settings - Fork 268
feat: Add JAX release support v0.10.0 #6054
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 17 commits
a4ca6a2
dd528cb
baef7fe
6fc71c7
74ab4d7
19eda7c
c170950
e6f750d
b9a4640
700e043
ec7f744
6c8f879
785787d
1c8dc98
1f13220
41c015b
099878f
88fc90d
22c916c
8b6f2f5
9d58a73
084898b
382f39e
dcf62ab
50c652d
9b1e7ed
ff1f03f
f7448a4
363b89c
0bdd100
3d46f1b
c15d36d
35a873a
c6bacdd
d1d2b54
ccc240e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,14 +29,14 @@ on: | |
| description: "JAX / rocm-jax release ref." | ||
| type: string | ||
| required: true | ||
| build_jaxlib: | ||
| description: "Whether this matrix cell should build jaxlib from JAX source." | ||
| type: boolean | ||
| required: true | ||
| rocm_version: | ||
| description: "ROCm package version to install and build against." | ||
| type: string | ||
| required: true | ||
| rocm_package_find_links_url: | ||
| description: "ROCm package index / find-links URL for the manylinux build." | ||
| type: string | ||
| required: true | ||
| tar_url: | ||
| description: "URL to the TheRock tarball used for the build." | ||
| type: string | ||
|
|
@@ -53,6 +53,19 @@ on: | |
| description: "Branch, tag, or SHA to checkout. Defaults to the triggering ref." | ||
| type: string | ||
| default: "" | ||
| jax_repository: | ||
| description: "Repository containing the JAX release branch." | ||
| type: string | ||
| required: true | ||
| build_mode: | ||
| description: "Build mode: native or manylinux." | ||
| type: string | ||
| required: true | ||
| gfx_arch: | ||
| description: "GFX architecture used for the manylinux image build." | ||
| type: string | ||
| required: false | ||
| default: "" | ||
| outputs: | ||
| package_index_url: | ||
| description: "Package index URL for the multi-arch release bucket." | ||
|
|
@@ -79,14 +92,14 @@ on: | |
| description: "JAX / rocm-jax release ref." | ||
| type: string | ||
| required: true | ||
| build_jaxlib: | ||
| description: "Whether this matrix cell should build jaxlib from JAX source." | ||
| type: boolean | ||
| required: true | ||
| rocm_version: | ||
| description: "ROCm package version to install and build against." | ||
| type: string | ||
| required: true | ||
| rocm_package_find_links_url: | ||
| description: "ROCm package index / find-links URL for the manylinux build." | ||
| type: string | ||
| required: true | ||
| tar_url: | ||
| description: "URL to the TheRock tarball used for the build." | ||
| type: string | ||
|
|
@@ -103,6 +116,19 @@ on: | |
| description: "Branch, tag, or SHA to checkout. Defaults to the triggering ref." | ||
| type: string | ||
| default: "" | ||
| jax_repository: | ||
| description: "Repository containing the JAX release branch." | ||
| type: string | ||
| required: true | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a default here, or at least mention which repositories are expected in the description? What syntax is this expecting, |
||
| build_mode: | ||
| description: "Build mode: native or manylinux." | ||
| type: string | ||
| required: true | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: If there are only two expected string values, for workflow_dispatch you could use a "choice" input type to make this easier to use correctly. |
||
| gfx_arch: | ||
| description: "GFX architecture used for the manylinux image build." | ||
| type: string | ||
| required: false | ||
| default: "" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add an example to the description? We have a few syntax styles ( |
||
|
|
||
| run-name: Build Multi-Arch Linux JAX Wheels (${{ inputs.release_type }}, ${{ inputs.rocm_version }}, Multiarch) | ||
|
|
||
|
|
@@ -121,7 +147,7 @@ jobs: | |
| jax_plugin_version: ${{ steps.write_jax_versions.outputs.jax_plugin_version }} | ||
| jax_pjrt_version: ${{ steps.write_jax_versions.outputs.jax_pjrt_version }} | ||
| env: | ||
| PACKAGE_DIST_DIR: ${{ github.workspace }}/rocm-jax/jax_rocm_plugin/wheelhouse | ||
| MANYLINUX_IMAGE_TAG: therock-jax-manylinux:${{ inputs.jax_ref }}-${{ inputs.python_version }} | ||
|
|
||
| steps: | ||
| - name: Checkout TheRock | ||
|
|
@@ -138,13 +164,14 @@ jobs: | |
| ref: ${{ inputs.jax_ref }} | ||
|
|
||
| - name: Checkout JAX | ||
| if: ${{ inputs.build_jaxlib }} | ||
| if: ${{ inputs.build_mode == 'manylinux' }} | ||
| uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 | ||
| with: | ||
| path: jax | ||
| repository: ROCm/jax | ||
| path: jax-source | ||
| repository: ${{ inputs.jax_repository }} | ||
| ref: ${{ inputs.jax_ref }} | ||
|
|
||
|
|
||
| - name: Configure Git Identity | ||
| run: | | ||
| git config --global user.name "therockbot" | ||
|
|
@@ -155,21 +182,20 @@ jobs: | |
| with: | ||
| python-version: ${{ inputs.python_version }} | ||
|
|
||
| - name: Install boto3 | ||
| - name: Install python deps for CI | ||
| run: | | ||
| python -m pip install boto3 | ||
| pip install -r external-builds/jax/requirements-jax.txt | ||
|
|
||
| - name: Determine build arguments | ||
| env: | ||
| BUILD_JAXLIB: ${{ inputs.build_jaxlib }} | ||
| - name: Set package dist dir | ||
| run: | | ||
| if [ "$BUILD_JAXLIB" = "true" ]; then | ||
| echo "SOURCE_ARG=--jax-source-dir=$GITHUB_WORKSPACE/jax" >> "$GITHUB_ENV" | ||
| if [ "${{ inputs.build_mode }}" = "manylinux" ]; then | ||
| echo "PACKAGE_DIST_DIR=${GITHUB_WORKSPACE}/jax-source/dist" >> "$GITHUB_ENV" | ||
| else | ||
| echo "SOURCE_ARG=" >> "$GITHUB_ENV" | ||
| echo "PACKAGE_DIST_DIR=${GITHUB_WORKSPACE}/rocm-jax/jax_rocm_plugin/wheelhouse" >> "$GITHUB_ENV" | ||
| fi | ||
|
|
||
| - name: Build JAX Wheels | ||
| - name: Build JAX Wheels (native, v0.9.1) | ||
| if: ${{ inputs.build_mode == 'native' }} | ||
| working-directory: rocm-jax | ||
| env: | ||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||
|
|
@@ -181,9 +207,63 @@ jobs: | |
| --python-versions="${PYTHON_VERSION}" \ | ||
| --rocm-version="${ROCM_VERSION}" \ | ||
| --therock-path="${TAR_URL}" \ | ||
| ${SOURCE_ARG} \ | ||
| dist_wheels | ||
|
|
||
| - name: Build manylinux image | ||
| if: ${{ inputs.build_mode == 'manylinux' }} | ||
| working-directory: jax-source | ||
| env: | ||
| THEROCK_VERSION: "" | ||
| GFX_ARCH: ${{ inputs.gfx_arch }} | ||
| ROCM_PACKAGE_FIND_LINKS_URL: ${{ inputs.rocm_package_find_links_url }} | ||
| run: | | ||
| cp ../rocm-jax/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock \ | ||
| Dockerfile.jax-manylinux_2_28-therock | ||
|
|
||
| docker build \ | ||
| -t "${MANYLINUX_IMAGE_TAG}" \ | ||
| --file=Dockerfile.jax-manylinux_2_28-therock \ | ||
| --build-arg=THEROCK_INDEX_URL="${ROCM_PACKAGE_FIND_LINKS_URL}" \ | ||
| --build-arg=THEROCK_VERSION="${THEROCK_VERSION}" \ | ||
| --build-arg=GFX_ARCH="${GFX_ARCH}" \ | ||
|
erman-gurses marked this conversation as resolved.
|
||
| . | ||
|
|
||
| - name: Determine wheel version suffix | ||
| run: | | ||
| python build_tools/github_actions/determine_version.py \ | ||
| --rocm-version "${{ inputs.rocm_version }}" | ||
|
|
||
| - name: Build JAX Wheels in manylinux container | ||
| if: ${{ inputs.build_mode == 'manylinux' }} | ||
| working-directory: jax-source | ||
| env: | ||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||
| PYTHON_VERSION: ${{ inputs.python_version }} | ||
| run: | | ||
| docker run --rm \ | ||
|
Comment on lines
+238
to
+245
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you also update the docs at https://github.com/ROCm/TheRock/tree/main/external-builds/jax#build-instructions ? (The general progression for development work and feature enablement should be local builds --> dev builds --> release builds, this jumps straight to release builds in github actions workflows which is the hardest to debug) |
||
| --user root \ | ||
| --env ROCM_VERSION="${ROCM_VERSION}" \ | ||
| --env PYTHON_VERSION="${PYTHON_VERSION}" \ | ||
| --env ML_WHEEL_VERSION_SUFFIX="${version_suffix}" \ | ||
| --env PACKAGE_DIST_DIR="/workspace/jax-source/dist" \ | ||
| --volume "${GITHUB_WORKSPACE}:/workspace" \ | ||
| --workdir /workspace/jax-source \ | ||
| "${MANYLINUX_IMAGE_TAG}" \ | ||
| bash -lc ' | ||
| python build/build.py build \ | ||
| --wheels=jax-rocm-plugin,jax-rocm-pjrt \ | ||
| --python_version="${PYTHON_VERSION}" \ | ||
| --bazel_startup_options=--bazelrc=build/rocm/rocm.bazelrc \ | ||
| --bazel_options=--config=rocm_release_wheel \ | ||
| --bazel_options=--repo_env=ROCM_PATH=$(rocm-sdk path --root) \ | ||
| --bazel_options=--repo_env=ML_WHEEL_TYPE=release \ | ||
| --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="${ML_WHEEL_VERSION_SUFFIX}" \ | ||
| --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD) \ | ||
| --verbose \ | ||
| --detailed_timestamped_log \ | ||
| --output_path=$(pwd)/dist | ||
| ' | ||
|
|
||
| - name: Extract JAX versions from built wheels | ||
| id: write_jax_versions | ||
| run: | | ||
|
|
@@ -203,7 +283,6 @@ jobs: | |
| --release-type="${{ inputs.release_type }}" | ||
|
|
||
|
|
||
|
|
||
| generate_target_to_run: | ||
| name: Generate target_to_run | ||
| if: ${{ inputs.test_amdgpu_family != '' }} | ||
|
|
@@ -236,7 +315,7 @@ jobs: | |
| test_amdgpu_family: ${{ inputs.test_amdgpu_family }} | ||
| package_index_url: ${{ needs.build_jax_wheels.outputs.package_index_url }} | ||
| rocm_version: ${{ inputs.rocm_version }} | ||
| tar_url: ${{ inputs.tar_url }} | ||
| rocm_package_find_links_url: ${{ inputs.rocm_package_find_links_url }} | ||
| release_type: ${{ inputs.release_type }} | ||
| python_version: ${{ inputs.python_version }} | ||
| repository: ${{ inputs.repository || github.repository }} | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -3,7 +3,6 @@ | |||||||||
|
|
||||||||||
| # Multi-arch release entry point for Linux JAX wheels. | ||||||||||
| # | ||||||||||
| # Currently this workflow is uses only rocm-jaxlib-v0.9.1 | ||||||||||
| # This workflow owns the release matrix and delegates each matrix cell to the | ||||||||||
| # reusable single-build workflow. | ||||||||||
|
|
||||||||||
|
|
@@ -24,6 +23,10 @@ on: | |||||||||
| description: "ROCm package version to build against." | ||||||||||
| type: string | ||||||||||
| required: true | ||||||||||
| rocm_package_find_links_url: | ||||||||||
| description: "ROCm package index / find-links URL for the JAX manylinux build." | ||||||||||
| type: string | ||||||||||
| default: "https://rocm.devreleases.amd.com/whl-multi-arch/" | ||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This URL is not compatible with TheRock/.github/workflows/multi_arch_build_portable_linux_pytorch_wheels.yml Lines 115 to 118 in 1371136
A |
||||||||||
| tar_url: | ||||||||||
| description: "URL to the TheRock tarball used for the build." | ||||||||||
| type: string | ||||||||||
|
|
@@ -52,6 +55,10 @@ on: | |||||||||
| description: "ROCm package version to build against." | ||||||||||
| type: string | ||||||||||
| required: true | ||||||||||
| rocm_package_find_links_url: | ||||||||||
| description: "ROCm package index / find-links URL for the JAX manylinux build." | ||||||||||
| type: string | ||||||||||
| default: "https://rocm.devreleases.amd.com/whl-multi-arch/" | ||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove default that won't work here too |
||||||||||
| tar_url: | ||||||||||
| description: "URL to the TheRock tarball used for the build." | ||||||||||
| type: string | ||||||||||
|
|
@@ -69,7 +76,6 @@ run-name: Release Linux JAX Wheels (${{ inputs.release_type }}, ${{ inputs.rocm_ | |||||||||
|
|
||||||||||
| permissions: | ||||||||||
| contents: read | ||||||||||
| #TODO: Add rocm-jaxlib-v0.10.0 once it is released and tested. | ||||||||||
| jobs: | ||||||||||
| build_jax_wheels: | ||||||||||
| name: Build | py ${{ matrix.python_version }} | jax ${{ matrix.jax_ref }} | ||||||||||
|
|
@@ -79,9 +85,17 @@ jobs: | |||||||||
| python_version: ["3.11", "3.12", "3.13", "3.14"] | ||||||||||
| jax_ref: | ||||||||||
| - "rocm-jaxlib-v0.9.1" | ||||||||||
| - "rocm-jaxlib-v0.10.0" | ||||||||||
| include: | ||||||||||
| - jax_ref: "rocm-jaxlib-v0.9.1" | ||||||||||
| build_jaxlib: false | ||||||||||
| jax_repository: "ROCm/rocm-jax" | ||||||||||
| build_mode: "native" | ||||||||||
| gfx_arch: "" | ||||||||||
|
|
||||||||||
| - jax_ref: "rocm-jaxlib-v0.10.0" | ||||||||||
| jax_repository: "ROCm/jax" | ||||||||||
| build_mode: "manylinux" | ||||||||||
| gfx_arch: "device-all" | ||||||||||
|
|
||||||||||
| permissions: | ||||||||||
| id-token: write | ||||||||||
|
|
@@ -91,8 +105,11 @@ jobs: | |||||||||
| test_amdgpu_family: ${{ inputs.test_amdgpu_family }} | ||||||||||
| python_version: ${{ matrix.python_version }} | ||||||||||
| jax_ref: ${{ matrix.jax_ref }} | ||||||||||
| build_jaxlib: ${{ matrix.build_jaxlib }} | ||||||||||
| jax_repository: ${{ matrix.jax_repository }} | ||||||||||
| build_mode: ${{ matrix.build_mode }} | ||||||||||
| gfx_arch: ${{ matrix.gfx_arch }} | ||||||||||
| rocm_version: ${{ inputs.rocm_version }} | ||||||||||
| rocm_package_find_links_url: ${{ inputs.rocm_package_find_links_url }} | ||||||||||
| tar_url: ${{ inputs.tar_url }} | ||||||||||
| release_type: ${{ inputs.release_type }} | ||||||||||
| repository: ${{ inputs.repository || github.repository }} | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -18,8 +18,8 @@ on: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| description: ROCm version | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| required: true | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| type: string | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tar_url: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| description: URL to TheRock tarball | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm_package_find_links_url: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| description: "ROCm package index / find-links URL for installing ROCm packages." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| required: true | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| type: string | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| release_type: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -76,10 +76,10 @@ on: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| description: ROCm version | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| type: string | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| default: "7.14.0.dev0" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tar_url: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| description: URL to TheRock tarball | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm_package_find_links_url: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| description: "ROCm package index / find-links URL for installing ROCm packages." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| required: true | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| type: string | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| default: "" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| release_type: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| description: 'Release type; developer-triggered jobs should use "dev".' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| type: choice | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -149,9 +149,9 @@ jobs: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| env: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| VIRTUAL_ENV: ${{ github.workspace }}/.venv | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TAR_URL: ${{ inputs.tar_url }} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PYTHON_VERSION: ${{ inputs.python_version }} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| WHEEL_INDEX_URL: ${{ inputs.package_index_url }} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ROCM_PACKAGE_FIND_LINKS_URL: ${{ inputs.rocm_package_find_links_url }} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| steps: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - name: Checkout | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -192,19 +192,6 @@ jobs: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run: | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pip install -r external-builds/jax/requirements-jax.txt | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - name: Configure ROCm from TheRock tarball | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run: | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| DEST="/opt/rocm-${{ inputs.rocm_version }}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| TARBALL="${RUNNER_TEMP}/therock.tar.gz" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rm -rf "${DEST}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| mkdir -p "${DEST}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| wget -O "${TARBALL}" "${TAR_URL}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tar -xzf "${TARBALL}" -C "${DEST}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ln -sfn "${DEST}" /opt/rocm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - name: Determine JAX package versions | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run: | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if [ -n "${{ inputs.jax_plugin_version }}" ] && [ -n "${{ inputs.jax_version }}" ]; then | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -220,6 +207,18 @@ jobs: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --jax-requirements rocm-jax/build/requirements.txt | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fi | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - name: Install ROCm packages | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run: | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| python -m pip install --upgrade pip | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| python -m pip install --pre \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| --index-url "${ROCM_PACKAGE_FIND_LINKS_URL}" \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wait. Which URL format do you want? They are different. The variable name here says "find links url" but it's installing with |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm-sdk-devel \ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do JAX tests need See my other comment, this should just install |
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm-sdk-device-gfx942 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Use this code instead:
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rocm-sdk init | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| echo "ROCM_ROOT=$(rocm-sdk path --root)" >> "$GITHUB_ENV" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ln -sfn "$(rocm-sdk path --root)" /opt/rocm | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please don't symlink rocm python packages to
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - name: Install JAX wheels from package index | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run: | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| echo "Installing from:" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -249,6 +248,7 @@ jobs: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| JAX_ROCM_PLUGIN_INTERNAL_BITCODE_PATH: /opt/rocm/lib/llvm/amdgcn/bitcode | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| JAX_ROCM_PLUGIN_INTERNAL_LLD_PATH: /opt/rocm/lib/llvm/bin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| run: | | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| python -c "import jax; print(jax.local_devices())" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pytest jax/tests/multi_device_test.py -q --log-cli-level=INFO | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pytest jax/tests/core_test.py -q --log-cli-level=INFO | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| pytest jax/tests/util_test.py -q --log-cli-level=INFO | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These
repositoryandrefinputs aren't commonly provided when triggered by developers, so I'd prefer to keep them as the last inputs in the list. Looks like a few other workflows have similarly started adding inputs beyond them though.I would move
jax_repository,build_mode, andgfx_archnear the top of the inputs list since they will be set frequently.