-
Notifications
You must be signed in to change notification settings - Fork 268
feat: Add JAX release support v0.10.0 #6054
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
a4ca6a2
dd528cb
baef7fe
6fc71c7
74ab4d7
19eda7c
c170950
e6f750d
b9a4640
700e043
ec7f744
6c8f879
785787d
1c8dc98
1f13220
41c015b
099878f
88fc90d
22c916c
8b6f2f5
9d58a73
084898b
382f39e
dcf62ab
50c652d
9b1e7ed
ff1f03f
f7448a4
363b89c
0bdd100
3d46f1b
c15d36d
35a873a
c6bacdd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,10 +29,6 @@ on: | |
| description: "JAX / rocm-jax release ref." | ||
| type: string | ||
| required: true | ||
| build_jaxlib: | ||
| description: "Whether this matrix cell should build jaxlib from JAX source." | ||
| type: boolean | ||
| required: true | ||
| rocm_version: | ||
| description: "ROCm package version to install and build against." | ||
| type: string | ||
|
|
@@ -53,6 +49,19 @@ on: | |
| description: "Branch, tag, or SHA to checkout. Defaults to the triggering ref." | ||
| type: string | ||
| default: "" | ||
| jax_repository: | ||
| description: "Repository containing the JAX release branch." | ||
| type: string | ||
| required: true | ||
| build_mode: | ||
| description: "Build mode: native or manylinux." | ||
| type: string | ||
| required: true | ||
| gfx_arch: | ||
| description: "GFX architecture used for the manylinux image build." | ||
| type: string | ||
| required: false | ||
| default: "" | ||
| outputs: | ||
| package_index_url: | ||
| description: "Package index URL for the multi-arch release bucket." | ||
|
|
@@ -79,10 +88,6 @@ on: | |
| description: "JAX / rocm-jax release ref." | ||
| type: string | ||
| required: true | ||
| build_jaxlib: | ||
| description: "Whether this matrix cell should build jaxlib from JAX source." | ||
| type: boolean | ||
| required: true | ||
| rocm_version: | ||
| description: "ROCm package version to install and build against." | ||
| type: string | ||
|
|
@@ -103,6 +108,19 @@ on: | |
| description: "Branch, tag, or SHA to checkout. Defaults to the triggering ref." | ||
| type: string | ||
| default: "" | ||
| jax_repository: | ||
| description: "Repository containing the JAX release branch." | ||
| type: string | ||
| required: true | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a default here, or at least mention which repositories are expected in the description? What syntax is this expecting, |
||
| build_mode: | ||
| description: "Build mode: native or manylinux." | ||
| type: string | ||
| required: true | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: If there are only two expected string values, for workflow_dispatch you could use a "choice" input type to make this easier to use correctly. |
||
| gfx_arch: | ||
| description: "GFX architecture used for the manylinux image build." | ||
| type: string | ||
| required: false | ||
| default: "" | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add an example to the description? We have a few syntax styles ( |
||
|
|
||
| run-name: Build Multi-Arch Linux JAX Wheels (${{ inputs.release_type }}, ${{ inputs.rocm_version }}, Multiarch) | ||
|
|
||
|
|
@@ -121,7 +139,7 @@ jobs: | |
| jax_plugin_version: ${{ steps.write_jax_versions.outputs.jax_plugin_version }} | ||
| jax_pjrt_version: ${{ steps.write_jax_versions.outputs.jax_pjrt_version }} | ||
| env: | ||
| PACKAGE_DIST_DIR: ${{ github.workspace }}/rocm-jax/jax_rocm_plugin/wheelhouse | ||
| MANYLINUX_IMAGE_TAG: therock-jax-manylinux:${{ inputs.jax_ref }}-${{ inputs.python_version }} | ||
|
|
||
| steps: | ||
| - name: Checkout TheRock | ||
|
|
@@ -131,18 +149,19 @@ jobs: | |
| ref: ${{ inputs.ref || github.ref_name }} | ||
|
|
||
| - name: Checkout rocm-jax | ||
| if: ${{ inputs.build_mode == 'native' }} | ||
| uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 | ||
| with: | ||
| path: rocm-jax | ||
| repository: ROCm/rocm-jax | ||
| ref: ${{ inputs.jax_ref }} | ||
|
|
||
| - name: Checkout JAX | ||
| if: ${{ inputs.build_jaxlib }} | ||
| if: ${{ inputs.build_mode == 'manylinux' }} | ||
| uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3 | ||
| with: | ||
| path: jax | ||
| repository: ROCm/jax | ||
| path: jax-source | ||
| repository: ${{ inputs.jax_repository }} | ||
| ref: ${{ inputs.jax_ref }} | ||
|
|
||
| - name: Configure Git Identity | ||
|
|
@@ -159,17 +178,16 @@ jobs: | |
| run: | | ||
| python -m pip install boto3 | ||
|
|
||
| - name: Determine build arguments | ||
| env: | ||
| BUILD_JAXLIB: ${{ inputs.build_jaxlib }} | ||
| - name: Set package dist dir | ||
| run: | | ||
| if [ "$BUILD_JAXLIB" = "true" ]; then | ||
| echo "SOURCE_ARG=--jax-source-dir=$GITHUB_WORKSPACE/jax" >> "$GITHUB_ENV" | ||
| if [ "${{ inputs.build_mode }}" = "manylinux" ]; then | ||
| echo "PACKAGE_DIST_DIR=${GITHUB_WORKSPACE}/jax-source/dist" >> "$GITHUB_ENV" | ||
| else | ||
| echo "SOURCE_ARG=" >> "$GITHUB_ENV" | ||
| echo "PACKAGE_DIST_DIR=${GITHUB_WORKSPACE}/rocm-jax/jax_rocm_plugin/wheelhouse" >> "$GITHUB_ENV" | ||
| fi | ||
|
|
||
| - name: Build JAX Wheels | ||
| - name: Build JAX Wheels (native, v0.9.1) | ||
| if: ${{ inputs.build_mode == 'native' }} | ||
| working-directory: rocm-jax | ||
| env: | ||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||
|
|
@@ -181,9 +199,77 @@ jobs: | |
| --python-versions="${PYTHON_VERSION}" \ | ||
| --rocm-version="${ROCM_VERSION}" \ | ||
| --therock-path="${TAR_URL}" \ | ||
| ${SOURCE_ARG} \ | ||
| dist_wheels | ||
|
|
||
| - name: Build manylinux image | ||
| if: ${{ inputs.build_mode == 'manylinux' }} | ||
| working-directory: jax-source | ||
| env: | ||
| BRANCH: ${{ inputs.jax_ref }} | ||
| THEROCK_INDEX_URL: https://rocm.nightlies.amd.com/whl-multi-arch/ | ||
|
erman-gurses marked this conversation as resolved.
Outdated
erman-gurses marked this conversation as resolved.
Outdated
|
||
| THEROCK_VERSION: "" | ||
| GFX_ARCH: ${{ inputs.gfx_arch }} | ||
| run: | | ||
| curl -fsSL \ | ||
| -o Dockerfile.jax-manylinux_2_28-therock \ | ||
| "https://raw.githubusercontent.com/ROCm/rocm-jax/${BRANCH}/docker/manylinux/Dockerfile.jax-manylinux_2_28-therock" | ||
|
erman-gurses marked this conversation as resolved.
Outdated
|
||
|
|
||
| docker build \ | ||
| -t "${MANYLINUX_IMAGE_TAG}" \ | ||
| --file=Dockerfile.jax-manylinux_2_28-therock \ | ||
| --build-arg=THEROCK_INDEX_URL="${THEROCK_INDEX_URL}" \ | ||
| --build-arg=THEROCK_VERSION="${THEROCK_VERSION}" \ | ||
| --build-arg=GFX_ARCH="${GFX_ARCH}" \ | ||
|
erman-gurses marked this conversation as resolved.
|
||
| . | ||
|
|
||
| - name: Compute wheel version suffix | ||
| if: ${{ inputs.build_mode == 'manylinux' }} | ||
| run: | | ||
| ROCM_VERSION_BASE="${ROCM_VERSION%%+*}" | ||
| ROCM_VERSION_BUILD="${ROCM_VERSION#*+}" | ||
|
|
||
| if [ "${RELEASE_TYPE}" = "nightly" ]; then | ||
| echo "ML_WHEEL_VERSION_SUFFIX=+rocm${ROCM_VERSION_BASE}a$(date -u +%Y%m%d)" >> "$GITHUB_ENV" | ||
| else | ||
| echo "ML_WHEEL_VERSION_SUFFIX=+rocm${ROCM_VERSION_BASE}.${ROCM_VERSION_BUILD}" >> "$GITHUB_ENV" | ||
| fi | ||
| env: | ||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||
| RELEASE_TYPE: ${{ inputs.release_type }} | ||
|
erman-gurses marked this conversation as resolved.
Outdated
|
||
|
|
||
| - name: Build JAX Wheels in manylinux container | ||
| if: ${{ inputs.build_mode == 'manylinux' }} | ||
| working-directory: jax-source | ||
| env: | ||
| ROCM_VERSION: ${{ inputs.rocm_version }} | ||
| PYTHON_VERSION: ${{ inputs.python_version }} | ||
| TAR_URL: ${{ inputs.tar_url }} | ||
| run: | | ||
| docker run --rm \ | ||
|
Comment on lines
+238
to
+245
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you also update the docs at https://github.com/ROCm/TheRock/tree/main/external-builds/jax#build-instructions ? (The general progression for development work and feature enablement should be local builds --> dev builds --> release builds, this jumps straight to release builds in github actions workflows which is the hardest to debug) |
||
| --user root \ | ||
| --env ROCM_VERSION="${ROCM_VERSION}" \ | ||
| --env PYTHON_VERSION="${PYTHON_VERSION}" \ | ||
| --env TAR_URL="${TAR_URL}" \ | ||
| --env ML_WHEEL_VERSION_SUFFIX="${ML_WHEEL_VERSION_SUFFIX}" \ | ||
| --env PACKAGE_DIST_DIR="/workspace/jax-source/dist" \ | ||
| --volume "${GITHUB_WORKSPACE}:/workspace" \ | ||
| --workdir /workspace/jax-source \ | ||
| "${MANYLINUX_IMAGE_TAG}" \ | ||
| bash -lc ' | ||
| python build/build.py build \ | ||
| --wheels=jax-rocm-plugin,jax-rocm-pjrt \ | ||
| --python_version="${PYTHON_VERSION}" \ | ||
| --bazel_startup_options=--bazelrc=build/rocm/rocm.bazelrc \ | ||
| --bazel_options=--config=rocm_release_wheel \ | ||
| --bazel_options=--repo_env=ROCM_PATH=$(rocm-sdk path --root) \ | ||
|
erman-gurses marked this conversation as resolved.
Outdated
|
||
| --bazel_options=--repo_env=ML_WHEEL_TYPE=release \ | ||
| --bazel_options=--repo_env=ML_WHEEL_VERSION_SUFFIX="${ML_WHEEL_VERSION_SUFFIX}" \ | ||
| --bazel_options=--//jaxlib/tools:jaxlib_git_hash=$(git rev-parse HEAD) \ | ||
| --verbose \ | ||
| --detailed_timestamped_log \ | ||
| --output_path=$(pwd)/dist | ||
| ' | ||
|
|
||
| - name: Extract JAX versions from built wheels | ||
| id: write_jax_versions | ||
| run: | | ||
|
|
@@ -203,7 +289,6 @@ jobs: | |
| --release-type="${{ inputs.release_type }}" | ||
|
|
||
|
|
||
|
|
||
| generate_target_to_run: | ||
| name: Generate target_to_run | ||
| if: ${{ inputs.test_amdgpu_family != '' }} | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These
repositoryandrefinputs aren't commonly provided when triggered by developers, so I'd prefer to keep them as the last inputs in the list. Looks like a few other workflows have similarly started adding inputs beyond them though.I would move
jax_repository,build_mode, andgfx_archnear the top of the inputs list since they will be set frequently.