+++
title = "JAX on TPU Guide"
description = "How to run JAX on Kubernetes with Kubeflow Trainer on Cloud TPU"
weight = 16
+++

This guide describes how to use TrainJob to train or fine-tune AI models with
[JAX](https://jax.readthedocs.io/) on Cloud TPU on Google Kubernetes Engine (GKE).

---

## Prerequisites

Before following this guide, make sure that you have:

- Completed [the Getting Started guide](https://www.kubeflow.org/docs/components/trainer/user-guides/).
- Set up a GKE cluster with TPU nodes by following the [GKE Cloud TPU documentation](https://cloud.google.com/kubernetes-engine/docs/concepts/tpus). For example, on a GKE Autopilot cluster you can create a custom TPU ComputeClass like this:
```yaml
apiVersion: cloud.google.com/v1
kind: ComputeClass
metadata:
  name: tpu-multihost-v5-8
spec:
  priorities:
  - tpu:
      type: tpu-v5-lite-podslice
      count: 4
      topology: 2x4
    nodePoolAutoCreation:
      enabled: true
```

---

## JAX on TPU Overview

JAX on TPU requires a different runtime environment than JAX on GPU. Specifically:

- **Image**: You must use a JAX image built for TPUs (e.g., `us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu`).
- **Resources**: You must request `google.com/tpu` resources.
- **Node Selectors**: You must specify GKE-specific node selectors and topology for TPU nodes.
- **Environment Variables**: Depending on your JAX version, some TPU-specific JAX environment variables might be required.

{{% alert title="Note" color="info" %}}
The built-in `jax-distributed` runtime is optimized for GPUs. For TPU workloads, you can override the runtime configuration using the Python SDK.
{{% /alert %}}

---

## JAX Distributed Environment on TPU

Your training script must explicitly initialize the JAX distributed runtime. The `jax-distributed` runtime provides the `JAX_COORDINATOR_ADDRESS`, `JAX_NUM_PROCESSES`, and `JAX_PROCESS_ID` environment variables, which the script below passes to `jax.distributed.initialize()`.

```python
from kubeflow.trainer import TrainerClient, CustomTrainer
from kubeflow.trainer.options import kubernetes as k8s_options

def get_jax_tpu_dist():
    import os
    import jax
    import jax.distributed as dist

    # Initialize distributed JAX.
    dist.initialize(
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
    )

    print("JAX Distributed Environment on TPU")
    print(f"Local devices: {jax.local_devices()}")
    print(f"Global device count: {jax.device_count()}")
    print(f"Process index: {jax.process_index()}")

    import jax.numpy as jnp

    # Use local_device_count for the leading axis of the local input to pmap.
    x = jnp.ones((jax.local_device_count(),))

    # Pass process_index as a sharded argument to ensure SPMD consistency
    # across all processes in the distributed job.
    p_idx = jnp.array([jax.process_index()] * jax.local_device_count())

    y = jax.pmap(lambda v, p: v * p)(x, p_idx)

    print("PMAP result:", y)

client = TrainerClient()

# Define TPU node selectors and tolerations.
# Replace these values with your GKE TPU configuration.
node_selector = {
    "cloud.google.com/compute-class": "tpu-multihost-v5-8",
    "cloud.google.com/gke-tpu-accelerator": "tpu-v5-lite-podslice",
    "cloud.google.com/gke-tpu-topology": "2x4",
}

job_patch = k8s_options.RuntimePatch(
    training_runtime_spec=k8s_options.TrainingRuntimeSpecPatch(
        template=k8s_options.JobSetTemplatePatch(
            spec=k8s_options.JobSetSpecPatch(
                replicated_jobs=[
                    k8s_options.ReplicatedJobPatch(
                        name="node",
                        template=k8s_options.JobTemplatePatch(
                            spec=k8s_options.JobSpecPatch(
                                template=k8s_options.PodTemplatePatch(
                                    spec=k8s_options.PodSpecPatch(
                                        node_selector=node_selector,
                                        tolerations=[
                                            {
                                                "key": "google.com/tpu",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                            {
                                                "key": "cloud.google.com/compute-class",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                        ],
                                    )
                                )
                            )
                        )
                    )
                ]
            )
        )
    )
)

# Create the TrainJob.
job_id = client.train(
    runtime="jax-distributed",
    trainer=CustomTrainer(
        func=get_jax_tpu_dist,
        image="us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest",
        num_nodes=2,
        resources_per_node={
            "google.com/tpu": 4,
        },
        env={
            "JAX_PLATFORMS": "tpu,cpu",
            "ENABLE_PJRT_COMPATIBILITY": "true",
        },
    ),
    options=[job_patch],
)

# Wait until completion.
client.wait_for_job_status(job_id)

# Logs are aggregated from node-0.
print("\n".join(client.get_job_logs(name=job_id)))
```
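The `pmap` call above is easy to reason about on the host: each local device holds one element of the ones-vector and multiplies it by the process index it was handed, so process 0 prints zeros and process N prints a vector of Ns. A plain-Python sketch of that arithmetic (the 4-device count is an assumption matching the `google.com/tpu: 4` request):

```python
def simulate_pmap(process_index: int, local_device_count: int = 4):
    # Mirror the pmap body `v * p`: one ones-element per "device",
    # each multiplied by the owning process index.
    x = [1.0] * local_device_count
    p_idx = [float(process_index)] * local_device_count
    return [v * p for v, p in zip(x, p_idx)]

print(simulate_pmap(0))  # [0.0, 0.0, 0.0, 0.0]
print(simulate_pmap(1))  # [1.0, 1.0, 1.0, 1.0]
```

This is why the `PMAP result` line differs across pods: it is a quick way to confirm that each process received a distinct `JAX_PROCESS_ID`.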

---

## End-to-end Training Example

The following example demonstrates how to train a simple CNN on the MNIST dataset using JAX on multi-host TPUs.

```python
from kubeflow.trainer import TrainerClient, CustomTrainer
from kubeflow.trainer.options import kubernetes as k8s_options

def train_mnist_jax():
    import os
    import jax
    import jax.numpy as jnp
    import jax.distributed as dist
    import optax
    from flax import linen as nn
    from flax.training import train_state
    import tensorflow_datasets as tfds
    import tensorflow as tf

    # Initialize distributed JAX.
    dist.initialize(
        coordinator_address=os.environ["JAX_COORDINATOR_ADDRESS"],
        num_processes=int(os.environ["JAX_NUM_PROCESSES"]),
        process_id=int(os.environ["JAX_PROCESS_ID"]),
    )

    # Prevent TensorFlow from grabbing the TPU.
    tf.config.set_visible_devices([], "TPU")

    process_index = jax.process_index()
    local_device_count = jax.local_device_count()

    print(f"Process: {process_index}")
    print(f"Global devices: {jax.device_count()}")
    print(f"Local devices: {jax.local_devices()}")

    # Model definition.
    class CNN(nn.Module):
        @nn.compact
        def __call__(self, x):
            x = nn.Conv(features=32, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, (2, 2), (2, 2))
            x = x.reshape((x.shape[0], -1))
            x = nn.Dense(128)(x)
            x = nn.relu(x)
            x = nn.Dense(10)(x)
            return x

    # Dataset sharding:
    # In JAX's SPMD model, each process runs the same code but should handle
    # different data to increase throughput. Without sharding, every node
    # would process the same images, wasting compute.
    ds = tfds.load("mnist", split="train", as_supervised=True)
    ds = ds.shard(num_shards=jax.process_count(), index=process_index)

    def preprocess(image, label):
        image = tf.cast(image, tf.float32) / 255.0
        return image, label

    # drop_remainder=True keeps batch shapes static, so the per-device
    # reshape in the training loop always succeeds and pmap does not
    # recompile for a ragged final batch.
    ds = ds.map(preprocess).batch(128, drop_remainder=True).prefetch(1)
    ds = tfds.as_numpy(ds)

    # Training setup.
    model = CNN()
    rng = jax.random.PRNGKey(0)

    params = model.init(rng, jnp.ones([1, 28, 28, 1]))["params"]

    tx = optax.adam(1e-3)

    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )

    # Replicate the state across local devices.
    state = jax.device_put_replicated(state, jax.local_devices())

    # Training step.
    def loss_fn(params, batch):
        images, labels = batch
        logits = model.apply({"params": params}, images)
        onehot = jax.nn.one_hot(labels, 10)
        loss = optax.softmax_cross_entropy(logits, onehot).mean()
        return loss

    grad_fn = jax.value_and_grad(loss_fn)

    def train_step(state, batch):
        loss, grads = grad_fn(state.params, batch)
        # Average gradients across all devices.
        # The "batch" axis must be bound in jax.pmap for pmean to work.
        grads = jax.lax.pmean(grads, axis_name="batch")
        state = state.apply_gradients(grads=grads)
        return state, loss

    train_step = jax.pmap(train_step, axis_name="batch")

    # Training loop.
    for epoch in range(5):
        for images, labels in ds:
            # Convert to jnp and shard the batch across local devices.
            images = jnp.array(images).reshape(
                (local_device_count, -1, 28, 28, 1)
            )
            labels = jnp.array(labels).reshape(
                (local_device_count, -1)
            )

            state, loss = train_step(state, (images, labels))

        if process_index == 0:
            print(f"Epoch {epoch}, Loss: {loss.mean()}")

client = TrainerClient()

# Define TPU node selectors and tolerations.
node_selector = {
    "cloud.google.com/compute-class": "tpu-multihost-v5-8",
    "cloud.google.com/gke-tpu-accelerator": "tpu-v5-lite-podslice",
    "cloud.google.com/gke-tpu-topology": "2x4",
}

job_patch = k8s_options.RuntimePatch(
    training_runtime_spec=k8s_options.TrainingRuntimeSpecPatch(
        template=k8s_options.JobSetTemplatePatch(
            spec=k8s_options.JobSetSpecPatch(
                replicated_jobs=[
                    k8s_options.ReplicatedJobPatch(
                        name="node",
                        template=k8s_options.JobTemplatePatch(
                            spec=k8s_options.JobSpecPatch(
                                template=k8s_options.PodTemplatePatch(
                                    spec=k8s_options.PodSpecPatch(
                                        node_selector=node_selector,
                                        tolerations=[
                                            {
                                                "key": "google.com/tpu",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                            {
                                                "key": "cloud.google.com/compute-class",
                                                "operator": "Exists",
                                                "effect": "NoSchedule",
                                            },
                                        ],
                                    )
                                )
                            )
                        )
                    )
                ]
            )
        )
    )
)

job_id = client.train(
    runtime="jax-distributed",
    trainer=CustomTrainer(
        func=train_mnist_jax,
        image="us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:latest",
        num_nodes=2,
        resources_per_node={
            "google.com/tpu": 4,
        },
        env={
            "JAX_PLATFORMS": "tpu,cpu",
            "ENABLE_PJRT_COMPATIBILITY": "true",
        },
        packages_to_install=[
            "tensorflow-datasets",
            "flax",
            "optax",
            "tensorflow",
        ],
    ),
    options=[job_patch],
)

client.wait_for_job_status(job_id)
print("\n".join(client.get_job_logs(name=job_id)))
```
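The reshape in the training loop determines the per-device micro-batch: with a per-process batch of 128 and 4 local TPU chips, each device trains on 32 examples per step. A host-side sketch of that shape arithmetic (the 4-chip-per-node count is an assumption matching the `google.com/tpu: 4` request):

```python
def per_device_batch(batch_size: int, local_device_count: int) -> int:
    # The reshape to (local_device_count, -1, 28, 28, 1) requires the
    # batch to divide evenly across local devices; ragged final batches
    # break it, which is why dropping them in the input pipeline
    # (e.g. tf.data's drop_remainder=True) is a common safeguard.
    if batch_size % local_device_count != 0:
        raise ValueError("batch size must be divisible by local device count")
    return batch_size // local_device_count

print(per_device_batch(128, 4))  # 32
```

If you change the batch size or move to a slice with a different chips-per-host layout, re-check this divisibility before launching the TrainJob.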

---

## TPU-Specific Configurations

### Node Selectors and Topology

When running on GKE, TPUs are often managed via [Compute Classes](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus-compute-class). The `node_selector` must match the labels on your TPU node pool:

| Label | Example Value |
|-------|---------------|
| `cloud.google.com/compute-class` | `tpu-multihost-v5-8` |
| `cloud.google.com/gke-tpu-accelerator` | `tpu-v5-lite-podslice` |
| `cloud.google.com/gke-tpu-topology` | `2x4` |
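The topology string also fixes how many chips and nodes the slice spans: `2x4` is 8 chips, and assuming 4 chips per host VM (the multi-host layout of `tpu-v5-lite-podslice` used in the examples), the slice covers 2 Kubernetes nodes, which is why the examples set `num_nodes=2`. A minimal sketch of that arithmetic:

```python
import math

def slice_shape(topology: str, chips_per_host: int = 4):
    # Parse a GKE TPU topology string such as "2x4" or "4x4" into the
    # total chip count and the number of host VMs (Kubernetes nodes).
    dims = [int(d) for d in topology.split("x")]
    total_chips = math.prod(dims)
    return total_chips, total_chips // chips_per_host

print(slice_shape("2x4"))  # (8, 2)
```

Keep `num_nodes`, `resources_per_node["google.com/tpu"]`, and the topology label consistent with each other, or the JobSet pods will not be schedulable on the slice.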

### Environment Variables

| Variable | Description |
|----------|-------------|
| `JAX_PLATFORMS` | Set to `tpu,cpu` to ensure JAX uses the TPU backend, with CPU as a fallback. |
| `ENABLE_PJRT_COMPATIBILITY` | Set to `true` for compatibility with newer JAX/libtpu versions. |

---

## Next Steps

- Learn more about [JAX distributed training](https://jax.readthedocs.io/en/latest/jax.distributed.html).
- Explore [GKE Cloud TPU best practices](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus).