-
Notifications
You must be signed in to change notification settings - Fork 268
feat: Add JAX release support v0.10.0 #6054
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
a4ca6a2
dd528cb
baef7fe
6fc71c7
74ab4d7
19eda7c
c170950
e6f750d
b9a4640
700e043
ec7f744
6c8f879
785787d
1c8dc98
1f13220
41c015b
099878f
88fc90d
22c916c
8b6f2f5
9d58a73
084898b
382f39e
dcf62ab
50c652d
9b1e7ed
ff1f03f
f7448a4
363b89c
0bdd100
3d46f1b
c15d36d
35a873a
c6bacdd
d1d2b54
ccc240e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,18 +25,31 @@ on: | |
| description: "Python version to build wheels for." | ||
| type: string | ||
| required: true | ||
| jax_repository: | ||
| description: "Repository containing the JAX release branch." | ||
| type: string | ||
| required: true | ||
| build_mode: | ||
| description: "Build mode: native or manylinux." | ||
| type: string | ||
| required: true | ||
| gfx_arch: | ||
| description: "ROCm device package selector used for the manylinux build (e.g. device-gfx942, device-gfx950, or device-all)." | ||
| type: string | ||
| required: false | ||
| default: "" | ||
| jax_ref: | ||
| description: "JAX / rocm-jax release ref." | ||
| type: string | ||
| required: true | ||
| build_jaxlib: | ||
| description: "Whether this matrix cell should build jaxlib from JAX source." | ||
| type: boolean | ||
| required: true | ||
| rocm_version: | ||
| description: "ROCm package version to install and build against." | ||
| type: string | ||
| required: true | ||
| rocm_package_index_url: | ||
| description: "URL for pip --index-url to install ROCm packages for the JAX manylinux build." | ||
| type: string | ||
| default: "" | ||
| tar_url: | ||
| description: "URL to the TheRock tarball used for the build." | ||
| type: string | ||
|
|
@@ -75,18 +88,34 @@ on: | |
| description: "Python version to build wheels for." | ||
| type: string | ||
| required: true | ||
| jax_repository: | ||
| description: "GitHub repository in owner/repo format (e.g. ROCm/jax or ROCm/rocm-jax)." | ||
| type: string | ||
| default: "ROCm/jax" | ||
| build_mode: | ||
| description: "Build mode." | ||
| type: choice | ||
| options: | ||
| - native | ||
| - manylinux | ||
| required: true | ||
| gfx_arch: | ||
| description: "ROCm device package selector used for the manylinux build (e.g. device-gfx942, device-gfx950, or device-all)." | ||
| type: string | ||
| required: false | ||
| default: "" | ||
| jax_ref: | ||
| description: "JAX / rocm-jax release ref." | ||
| type: string | ||
| required: true | ||
| build_jaxlib: | ||
| description: "Whether this matrix cell should build jaxlib from JAX source." | ||
| type: boolean | ||
| required: true | ||
| rocm_version: | ||
| description: "ROCm package version to install and build against." | ||
| type: string | ||
| required: true | ||
| rocm_package_index_url: | ||
| description: "URL for pip --index-url to install ROCm packages for the JAX manylinux build." | ||
| type: string | ||
| default: "" | ||
| tar_url: | ||
| description: "URL to the TheRock tarball used for the build." | ||
| type: string | ||
|
|
@@ -103,7 +132,6 @@ on: | |
| description: "Branch, tag, or SHA to checkout. Defaults to the triggering ref." | ||
| type: string | ||
| default: "" | ||
|
|
||
| run-name: Build Multi-Arch Linux JAX Wheels (${{ inputs.release_type }}, ${{ inputs.rocm_version }}, Multiarch) | ||
|
|
||
| permissions: | ||
|
|
@@ -121,7 +149,7 @@ jobs: | |
| jax_plugin_version: ${{ steps.write_jax_versions.outputs.jax_plugin_version }} | ||
| jax_pjrt_version: ${{ steps.write_jax_versions.outputs.jax_pjrt_version }} | ||
| env: | ||
| PACKAGE_DIST_DIR: ${{ github.workspace }}/rocm-jax/jax_rocm_plugin/wheelhouse | ||
| MANYLINUX_IMAGE_TAG: therock-jax-manylinux:${{ inputs.jax_ref }}-${{ inputs.python_version }} | ||
|
|
||
| steps: | ||
| - name: Checkout TheRock | ||
|
|
@@ -138,13 +166,14 @@ jobs: | |
| ref: ${{ inputs.jax_ref }} | ||
|
|
||
| - name: Checkout JAX | ||
| if: ${{ inputs.build_jaxlib }} | ||
| if: ${{ inputs.build_mode == 'manylinux' }} | ||
| uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 | ||
| with: | ||
| path: jax | ||
| repository: ROCm/jax | ||
| path: jax-source | ||
| repository: ${{ inputs.jax_repository }} | ||
| ref: ${{ inputs.jax_ref }} | ||
|
|
||
|
|
||
| - name: Configure Git Identity | ||
| run: | | ||
| git config --global user.name "therockbot" | ||
|
|
@@ -155,21 +184,20 @@ jobs: | |
| with: | ||
| python-version: ${{ inputs.python_version }} | ||
|
|
||
| - name: Install boto3 | ||
| - name: Install python deps for CI | ||
| run: | | ||
| python -m pip install boto3 | ||
| pip install -r external-builds/jax/requirements-jax.txt | ||
|
|
||
| - name: Determine build arguments | ||
| env: | ||
| BUILD_JAXLIB: ${{ inputs.build_jaxlib }} | ||
| - name: Set package dist dir | ||
| run: | | ||
| if [ "$BUILD_JAXLIB" = "true" ]; then | ||
| echo "SOURCE_ARG=--jax-source-dir=$GITHUB_WORKSPACE/jax" >> "$GITHUB_ENV" | ||
| if [ "${{ inputs.build_mode }}" = "manylinux" ]; then | ||
| echo "PACKAGE_DIST_DIR=${GITHUB_WORKSPACE}/jax-source/dist" >> "$GITHUB_ENV" | ||
| else | ||
| echo "SOURCE_ARG=" >> "$GITHUB_ENV" | ||
| echo "PACKAGE_DIST_DIR=${GITHUB_WORKSPACE}/rocm-jax/jax_rocm_plugin/wheelhouse" >> "$GITHUB_ENV" | ||
| fi | ||
|
|
||
| - name: Build JAX Wheels | ||
| - name: Build JAX Wheels (native, v0.9.1) | ||
| if: ${{ inputs.build_mode == 'native' }} | ||
| working-directory: rocm-jax | ||
| env: | ||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||
|
|
@@ -181,9 +209,63 @@ jobs: | |
| --python-versions="${PYTHON_VERSION}" \ | ||
| --rocm-version="${ROCM_VERSION}" \ | ||
| --therock-path="${TAR_URL}" \ | ||
| ${SOURCE_ARG} \ | ||
| dist_wheels | ||
|
|
||
| - name: Build manylinux image | ||
| if: ${{ inputs.build_mode == 'manylinux' }} | ||
| working-directory: jax-source | ||
| env: | ||
| THEROCK_VERSION: "" | ||
| GFX_ARCH: ${{ inputs.gfx_arch }} | ||
| ROCM_PACKAGE_INDEX_URL: ${{ inputs.rocm_package_index_url }} | ||
| run: | | ||
| cp ../rocm-jax/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock \ | ||
| Dockerfile.jax-manylinux_2_28-therock | ||
|
|
||
| docker build \ | ||
| -t "${MANYLINUX_IMAGE_TAG}" \ | ||
| --file=Dockerfile.jax-manylinux_2_28-therock \ | ||
| --build-arg=THEROCK_INDEX_URL="${ROCM_PACKAGE_INDEX_URL}" \ | ||
| --build-arg=THEROCK_VERSION="${THEROCK_VERSION}" \ | ||
| --build-arg=GFX_ARCH="${GFX_ARCH}" \ | ||
| . | ||
|
|
||
| - name: Determine wheel version suffix | ||
| run: | | ||
| python build_tools/github_actions/determine_version.py \ | ||
| --rocm-version "${{ inputs.rocm_version }}" | ||
|
|
||
| - name: Build JAX Wheels in manylinux container | ||
| if: ${{ inputs.build_mode == 'manylinux' }} | ||
| working-directory: jax-source | ||
| env: | ||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||
| PYTHON_VERSION: ${{ inputs.python_version }} | ||
| run: | | ||
| docker run --rm \ | ||
|
Comment on lines
+238
to
+245
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you also update the docs at https://github.com/ROCm/TheRock/tree/main/external-builds/jax#build-instructions ? (The general progression for development work and feature enablement should be local builds --> dev builds --> release builds, this jumps straight to release builds in github actions workflows which is the hardest to debug) |
||
| --user root \ | ||
| --env ROCM_VERSION="${ROCM_VERSION}" \ | ||
| --env PYTHON_VERSION="${PYTHON_VERSION}" \ | ||
| --env ML_WHEEL_VERSION_SUFFIX="${version_suffix}" \ | ||
| --env PACKAGE_DIST_DIR="/workspace/jax-source/dist" \ | ||
| --volume "${GITHUB_WORKSPACE}:/workspace" \ | ||
| --workdir /workspace/jax-source \ | ||
| "${MANYLINUX_IMAGE_TAG}" \ | ||
| bash -lc ' | ||
| python build/build.py build \ | ||
| --wheels=jax-rocm-plugin,jax-rocm-pjrt \ | ||
| --python_version="${PYTHON_VERSION}" \ | ||
| --bazel_startup_options=--bazelrc=build/rocm/rocm.bazelrc \ | ||
| --bazel_options=--config=rocm_release_wheel \ | ||
| --bazel_options=--repo_env=ROCM_PATH=$(rocm-sdk path --root) \ | ||
| --bazel_options=--repo_env=ML_WHEEL_TYPE=release \ | ||
| --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="${ML_WHEEL_VERSION_SUFFIX}" \ | ||
| --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD) \ | ||
| --verbose \ | ||
| --detailed_timestamped_log \ | ||
| --output_path=$(pwd)/dist | ||
| ' | ||
|
|
||
| - name: Extract JAX versions from built wheels | ||
| id: write_jax_versions | ||
| run: | | ||
|
|
@@ -203,7 +285,6 @@ jobs: | |
| --release-type="${{ inputs.release_type }}" | ||
|
|
||
|
|
||
|
|
||
| generate_target_to_run: | ||
| name: Generate target_to_run | ||
| if: ${{ inputs.test_amdgpu_family != '' }} | ||
|
|
@@ -245,7 +326,7 @@ jobs: | |
| test_amdgpu_family: ${{ inputs.test_amdgpu_family }} | ||
| package_index_url: ${{ needs.build_jax_wheels.outputs.package_index_url }} | ||
| rocm_version: ${{ inputs.rocm_version }} | ||
| tar_url: ${{ inputs.tar_url }} | ||
| rocm_package_index_url: ${{ inputs.rocm_package_index_url }} | ||
| release_type: ${{ inputs.release_type }} | ||
| python_version: ${{ inputs.python_version }} | ||
| repository: ${{ inputs.repository || github.repository }} | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.