Commit 7e03585

[trainer] add Jax trainer guide for TPU
Signed-off-by: siyuanfoundation <sizhang@google.com>
1 parent 223ca58 commit 7e03585

2 files changed: +376 additions, −2 deletions

Lines changed: 371 additions & 0 deletions
@@ -0,0 +1,371 @@
+++
title = "JAX on TPU Guide"
description = "How to run JAX on Kubernetes with Kubeflow Trainer on Cloud TPU"
weight = 16
+++

This guide describes how to use TrainJob to train or fine-tune AI models with
[JAX](https://jax.readthedocs.io/) on Cloud TPU on Google Kubernetes Engine (GKE).

---

## Prerequisites

Before exploring this guide, make sure to follow:

- [The Getting Started guide](https://www.kubeflow.org/docs/components/trainer/user-guides/)
- [GKE Cloud TPU documentation](https://cloud.google.com/kubernetes-engine/docs/concepts/tpus) to set up a GKE cluster with TPU nodes. For example, on a GKE Autopilot cluster you can create a custom TPU ComputeClass like:

  ```yaml
  apiVersion: cloud.google.com/v1
  kind: ComputeClass
  metadata:
    name: tpu-multihost-v5-8
  spec:
    priorities:
    - tpu:
        type: tpu-v5-lite-podslice
        count: 4
        topology: 2x4
    nodePoolAutoCreation:
      enabled: true
  ```

---
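To see how this ComputeClass maps onto the TrainJob examples later in this guide: the `2x4` topology describes an 8-chip slice, and with `count: 4` chips per node the slice spans 2 nodes. A small sketch of that arithmetic (the helper name is ours, not part of any SDK):

```python
import math


def tpu_slice_shape(topology: str, chips_per_node: int):
    """Derive total chip count and node count from a GKE TPU topology string.

    Mirrors the ComputeClass above: topology "2x4" is an 8-chip slice,
    and with 4 chips per node it spans 2 nodes.
    """
    total_chips = math.prod(int(d) for d in topology.split("x"))
    num_nodes, remainder = divmod(total_chips, chips_per_node)
    if remainder:
        raise ValueError("topology does not divide evenly across nodes")
    return total_chips, num_nodes


print(tpu_slice_shape("2x4", 4))  # (8, 2)
```

These two numbers reappear below as `num_nodes=2` and `"google.com/tpu": 4`.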

## JAX on TPU Overview

JAX on TPU requires a different runtime environment than GPU. Specifically:

- **Image**: You must use a JAX image compatible with TPUs (e.g., `us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu`).
- **Resources**: You must request `google.com/tpu` resources.
- **Node Selectors**: You must specify GKE-specific node selectors and topology for TPU nodes.
- **Environment Variables**: Some TPU-specific JAX environment variables might be required, depending on your JAX version.

{{% alert title="Note" color="info" %}}
The built-in `jax-distributed` runtime is optimized for GPUs. For TPU workloads, you can override the runtime configuration using the Python SDK.
{{% /alert %}}

---

## JAX Distributed Environment on TPU

Your training script must explicitly initialize the JAX distributed runtime:

```python
from kubeflow.trainer import TrainerClient, CustomTrainer
from kubeflow.trainer.options import kubernetes as k8s_options


def get_jax_tpu_dist():
    import os

    import jax
    import jax.distributed as dist
    import jax.numpy as jnp

    # Initialize distributed JAX.
    dist.initialize(
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
    )

    print("JAX Distributed Environment on TPU")
    print(f"Local devices: {jax.local_devices()}")
    print(f"Global device count: {jax.device_count()}")
    print(f"Process index: {jax.process_index()}")

    # Use local_device_count for the leading axis of the local input to pmap.
    x = jnp.ones((jax.local_device_count(),))

    # Pass process_index as a sharded argument to ensure SPMD consistency
    # across all processes in the distributed job.
    p_idx = jnp.array([jax.process_index()] * jax.local_device_count())

    y = jax.pmap(lambda v, p: v * p)(x, p_idx)

    print("PMAP result:", y)


client = TrainerClient()

# Define TPU node selectors and tolerations.
# Replace with your GKE TPU configuration.
node_selector = {
    "cloud.google.com/compute-class": "tpu-multihost-v5-8",
    "cloud.google.com/gke-tpu-accelerator": "tpu-v5-lite-podslice",
    "cloud.google.com/gke-tpu-topology": "2x4",
}

job_patch = k8s_options.RuntimePatch(
    training_runtime_spec=k8s_options.TrainingRuntimeSpecPatch(
        template=k8s_options.JobSetTemplatePatch(
            spec=k8s_options.JobSetSpecPatch(
                replicated_jobs=[
                    k8s_options.ReplicatedJobPatch(
                        name="node",
                        template=k8s_options.JobTemplatePatch(
                            spec=k8s_options.JobSpecPatch(
                                template=k8s_options.PodTemplatePatch(
                                    spec=k8s_options.PodSpecPatch(
                                        node_selector=node_selector,
                                        tolerations=[
                                            {
                                                "key": "google.com/tpu",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                            {
                                                "key": "cloud.google.com/compute-class",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                        ],
                                    )
                                )
                            )
                        ),
                    )
                ]
            )
        )
    )
)

# Create the TrainJob.
job_id = client.train(
    runtime="jax-distributed",
    trainer=CustomTrainer(
        func=get_jax_tpu_dist,
        image="us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest",
        num_nodes=2,
        resources_per_node={
            "google.com/tpu": 4,
        },
        env={
            "JAX_PLATFORMS": "tpu,cpu",
            "ENABLE_PJRT_COMPATIBILITY": "true",
        },
    ),
    options=[job_patch],
)

# Wait until completion.
client.wait_for_job_status(job_id)

# Logs are aggregated from node-0.
print("\n".join(client.get_job_logs(name=job_id)))
```

---
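To see why `p_idx` is passed as a sharded argument, here is a plain-Python emulation (no JAX required) of the per-device computation `y = v * p` from the script above, assuming the guide's 2-process, 4-devices-per-process layout:

```python
def emulate_pmap(num_processes=2, local_device_count=4):
    """Emulate the per-device computation y = v * p from the example above.

    Each process holds one row of the global computation; within a process,
    every local device receives the same process index.
    """
    results = {}
    for process_index in range(num_processes):
        x = [1.0] * local_device_count            # jnp.ones((local_device_count,))
        p_idx = [process_index] * local_device_count
        results[process_index] = [v * p for v, p in zip(x, p_idx)]
    return results


print(emulate_pmap())
```

Process 0 prints all zeros and process 1 all ones, confirming that each process computes over its own shard while running identical code.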

## End-to-end Training Example

The following example demonstrates how to train a simple CNN on the MNIST dataset using JAX on multi-host TPUs.

```python
from kubeflow.trainer import TrainerClient, CustomTrainer
from kubeflow.trainer.options import kubernetes as k8s_options


def train_mnist_jax():
    import os

    import jax
    import jax.distributed as dist
    import jax.numpy as jnp
    import optax
    import tensorflow as tf
    import tensorflow_datasets as tfds
    from flax import linen as nn
    from flax.training import train_state

    # Initialize distributed JAX.
    dist.initialize(
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
    )

    # Prevent TensorFlow from grabbing the TPU.
    tf.config.set_visible_devices([], "TPU")

    process_index = jax.process_index()
    local_device_count = jax.local_device_count()

    print(f"Process: {process_index}")
    print(f"Global devices: {jax.device_count()}")
    print(f"Local devices: {jax.local_devices()}")

    # Model definition.
    class CNN(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Conv(features=32, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, (2, 2), (2, 2))
            x = x.reshape((x.shape[0], -1))
            x = nn.Dense(128)(x)
            x = nn.relu(x)
            x = nn.Dense(10)(x)
            return x

    # Dataset sharding:
    # In JAX's SPMD model, each process runs the same code but should handle
    # different data to increase throughput. Without sharding, every node
    # would process the same images, wasting compute.
    ds = tfds.load("mnist", split="train", as_supervised=True)
    ds = ds.shard(num_shards=jax.process_count(), index=process_index)

    def preprocess(image, label):
        image = tf.cast(image, tf.float32) / 255.0
        return image, label

    # drop_remainder=True keeps every batch evenly divisible across the
    # local devices for the reshape in the training loop below.
    ds = ds.map(preprocess).batch(128, drop_remainder=True).prefetch(1)
    ds = tfds.as_numpy(ds)

    # Training setup.
    model = CNN()
    rng = jax.random.PRNGKey(0)

    params = model.init(rng, jnp.ones([1, 28, 28, 1]))["params"]

    tx = optax.adam(1e-3)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )

    # Replicate the train state across local devices.
    state = jax.device_put_replicated(state, jax.local_devices())

    # Training step.
    def loss_fn(params, batch):
        images, labels = batch
        logits = model.apply({"params": params}, images)
        onehot = jax.nn.one_hot(labels, 10)
        loss = optax.softmax_cross_entropy(logits, onehot).mean()
        return loss

    grad_fn = jax.value_and_grad(loss_fn)

    def train_step(state, batch):
        loss, grads = grad_fn(state.params, batch)
        # Average gradients across all devices.
        # The "batch" axis must be bound in jax.pmap for pmean to work.
        grads = jax.lax.pmean(grads, axis_name="batch")
        state = state.apply_gradients(grads=grads)
        return state, loss

    train_step = jax.pmap(train_step, axis_name="batch")

    # Training loop.
    for epoch in range(5):
        for images, labels in ds:
            # Convert to jnp and shard the batch across local devices.
            images = jnp.array(images).reshape(
                (local_device_count, -1, 28, 28, 1)
            )
            labels = jnp.array(labels).reshape(
                (local_device_count, -1)
            )

            state, loss = train_step(state, (images, labels))

        if process_index == 0:
            print(f"Epoch {epoch}, Loss: {loss.mean()}")


client = TrainerClient()

# Define TPU node selectors and tolerations.
node_selector = {
    "cloud.google.com/compute-class": "tpu-multihost-v5-8",
    "cloud.google.com/gke-tpu-accelerator": "tpu-v5-lite-podslice",
    "cloud.google.com/gke-tpu-topology": "2x4",
}

job_patch = k8s_options.RuntimePatch(
    training_runtime_spec=k8s_options.TrainingRuntimeSpecPatch(
        template=k8s_options.JobSetTemplatePatch(
            spec=k8s_options.JobSetSpecPatch(
                replicated_jobs=[
                    k8s_options.ReplicatedJobPatch(
                        name="node",
                        template=k8s_options.JobTemplatePatch(
                            spec=k8s_options.JobSpecPatch(
                                template=k8s_options.PodTemplatePatch(
                                    spec=k8s_options.PodSpecPatch(
                                        node_selector=node_selector,
                                        tolerations=[
                                            {
                                                "key": "google.com/tpu",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                            {
                                                "key": "cloud.google.com/compute-class",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                        ],
                                    )
                                )
                            )
                        ),
                    )
                ]
            )
        )
    )
)

job_id = client.train(
    runtime="jax-distributed",
    trainer=CustomTrainer(
        func=train_mnist_jax,
        image="us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest",
        num_nodes=2,
        resources_per_node={
            "google.com/tpu": 4,
        },
        env={
            "JAX_PLATFORMS": "tpu,cpu",
            "ENABLE_PJRT_COMPATIBILITY": "true",
        },
        packages_to_install=[
            "tensorflow-datasets",
            "flax",
            "optax",
            "tensorflow",
        ],
    ),
    options=[job_patch],
)

client.wait_for_job_status(job_id)
print("\n".join(client.get_job_logs(name=job_id)))
```
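Two pieces of arithmetic in the example above are worth pulling out: `ds.shard` gives each process a disjoint, round-robin subset of the data, and the reshape in the training loop splits each 128-example batch evenly across the local devices. A plain-Python emulation of both (no TensorFlow or JAX required; helper names are ours):

```python
def shard(examples, num_shards, index):
    # Round-robin sharding, matching tf.data.Dataset.shard semantics:
    # element i is assigned to the shard where i % num_shards == index.
    return [x for i, x in enumerate(examples) if i % num_shards == index]


def per_device_batch_size(batch_size, local_device_count):
    # The reshape (local_device_count, -1, ...) implies this per-device size;
    # the batch size must divide evenly across the local devices.
    per_device, remainder = divmod(batch_size, local_device_count)
    assert remainder == 0, "batch size must divide evenly across local devices"
    return per_device


print(shard(list(range(8)), num_shards=2, index=0))  # [0, 2, 4, 6]
print(shard(list(range(8)), num_shards=2, index=1))  # [1, 3, 5, 7]
print(per_device_batch_size(128, 4))                 # 32
```

With 2 processes and 4 local devices each, every global step therefore consumes 2 × 128 = 256 distinct examples, 32 per chip.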
---

## TPU Specific Configurations

### Node Selectors and Topology

When running on GKE, TPUs are often managed via [Compute Classes](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus-compute-class). You must match the `node_selector` to your TPU node pool labels:

| Label | Example Value |
|-------|---------------|
| `cloud.google.com/compute-class` | `tpu-multihost-v5-8` |
| `cloud.google.com/gke-tpu-accelerator` | `tpu-v5-lite-podslice` |
| `cloud.google.com/gke-tpu-topology` | `2x4` |

### Environment Variables

| Variable | Description |
|----------|-------------|
| `JAX_PLATFORMS` | Set to `tpu,cpu` to ensure JAX uses the TPU backend. |
| `ENABLE_PJRT_COMPATIBILITY` | Set to `true` for compatibility with newer JAX/LibTPU versions. |
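These are ordinary environment variables, so a training function can sanity-check them before touching JAX. A minimal, stdlib-only sketch (the `setdefault` values stand in for what `env=` injects in a real TrainJob; JAX itself is deliberately not imported here):

```python
import os

# Placeholder values; in a TrainJob these are injected via `env=` before
# the training function runs.
os.environ.setdefault("JAX_PLATFORMS", "tpu,cpu")
os.environ.setdefault("ENABLE_PJRT_COMPATIBILITY", "true")

# Defensive check before initializing JAX: TPU should be the first
# (preferred) entry in the platform list.
platforms = [p.strip() for p in os.environ["JAX_PLATFORMS"].split(",")]
assert platforms[0] == "tpu", "TPU must be the preferred JAX backend"
print(platforms)  # ['tpu', 'cpu']
```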

---

## Next Steps

- Learn more about [JAX distributed training](https://jax.readthedocs.io/en/latest/jax.distributed.html).
- Explore [GKE Cloud TPU best practices](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus).

content/en/docs/components/trainer/user-guides/jax.md

Lines changed: 5 additions & 2 deletions
```diff
@@ -36,11 +36,14 @@ With Kubeflow Trainer, you can run:
 - Data-parallel and model-parallel JAX workloads
 
 {{% alert title="Note" color="info" %}}
-The JAX runtime currently supports CPU and GPU workloads only.
+The default JAX runtime currently supports CPU and GPU workloads only.
 
-TPU workloads are not supported because installing both `jax[cuda]`
+TPU workloads are not supported in the default JAX runtime because installing both `jax[cuda]`
 and `jax[tpu]` in the same image leads to backend and plugin conflicts.
 A separate TPU-specific runtime is required.
+
+Check out [the JAX on TPU guide](https://www.kubeflow.org/docs/components/trainer/user-guides/jax-tpu/)
+for more details on how to run JAX on Cloud TPU.
 {{% /alert %}}
 
 ---
```
