5746 tokens job: RESOURCE_EXHAUSTED (H100 80G, despite unified memory, jax flag and increase bucket size) #341

@GitCeliniHub

Description

Hi @Augustin-Zidek

It was lovely and very helpful to meet you guys on Thursday :)

As you suggested, and in order to get my 5746-token jobs to succeed, I threw all the additional flags at it (with the help of Claude, as this is all way beyond my skills and knowledge), and I still run into RESOURCE_EXHAUSTED.

```bash
#!/bin/bash

#SBATCH --job-name=AF3-4GPU
#SBATCH --mail-type=END,FAIL
#SBATCH --mail-user=xxx@crick.ac.uk
#SBATCH --partition=gh100
#SBATCH --reservation=h100
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=100
#SBATCH --mem=0
#SBATCH --time=72:00:00
#SBATCH --output=sbatch%j.log

ml purge
ml Singularity

# Create JAX cache directory
mkdir -p /xxx/jax_cache
chmod 777 /xxx/jax_cache

# Create temporary directory for model_config
TEMP_DIR=$(mktemp -d)
chmod 777 $TEMP_DIR

# Create model_config.py with documented sharding strategy
cat > $TEMP_DIR/model_config.py << 'EOL'
from typing import Sequence
from typing_extensions import TypeAlias

_Shape2DType: TypeAlias = tuple[int | None, int | None]

pair_transition_shard_spec: Sequence[_Shape2DType] = (
    (2048, None),
    (3072, 1024),
    (None, 512),
)
EOL

# Settings from documentation
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export TF_FORCE_UNIFIED_MEMORY=true
export XLA_CLIENT_MEM_FRACTION=3.2
export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false"

singularity exec \
    --nv \
    --bind /xxx/:/root/af_input \
    --bind /xxx/:/root/af_output \
    --bind /flask/reference/Alphafold3_dataset/model_parameters:/root/models \
    --bind /flask/reference/Alphafold3_dataset/datasets:/root/public_databases \
    --bind /xxx/jax_cache:/root/jax_cache \
    --bind $TEMP_DIR/model_config.py:/app/alphafold/alphafold3/model/model_config.py \
    /flask/apps/containers/Alphafold/3.0.0/alphafold3.sif \
    python /app/alphafold/run_alphafold.py \
    --json_path=/root/af_input/fold_input.json \
    --model_dir=/root/models \
    --db_dir=/root/public_databases \
    --output_dir=/root/af_output \
    --buckets 5746 \
    --jax_compilation_cache_dir=/root/jax_cache

# Cleanup temporary directory
rm -rf $TEMP_DIR
```
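In case it helps anyone reading along: the idea behind `pair_transition_shard_spec` is to process the pair activations in chunks so the transition layer never materialises its full expanded intermediate at once. A minimal NumPy sketch of that kind of row-chunking (purely illustrative, not AlphaFold 3's actual code; the shapes, weights, and `chunked_transition` helper are made up):

```python
import numpy as np

rng = np.random.default_rng(0)
n, c = 8, 4  # toy sizes; real pair activations are (num_tokens, num_tokens, channels)
pair = rng.standard_normal((n, n, c)).astype(np.float32)
w_in = rng.standard_normal((c, 4 * c)).astype(np.float32)
w_out = rng.standard_normal((4 * c, c)).astype(np.float32)

def transition(x):
    # Toy stand-in for a pair transition: expand channels 4x, ReLU, project back.
    return np.maximum(x @ w_in, 0.0) @ w_out

def chunked_transition(pair, chunk_rows):
    # Apply the transition to row-chunks of the pair tensor, so the 4x-expanded
    # intermediate peaks at chunk_rows * n * 4c elements instead of n * n * 4c.
    out = np.empty_like(pair)
    for start in range(0, pair.shape[0], chunk_rows):
        out[start:start + chunk_rows] = transition(pair[start:start + chunk_rows])
    return out

# Chunked and unchunked results agree; only peak memory differs.
assert np.allclose(chunked_transition(pair, 3), transition(pair), atol=1e-5)
```

The spec in the script presumably picks a chunk shape per bucket size along these lines, trading compute locality for a smaller peak intermediate.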

It runs for a few hours and then fails with:

```
Running model inference for seed 1...
Traceback (most recent call last):
  File "/app/alphafold/run_alphafold.py", line 699, in <module>
    app.run(main)
  File "/alphafold3_venv/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/alphafold3_venv/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 684, in main
    process_fold_input(
  File "/app/alphafold/run_alphafold.py", line 556, in process_fold_input
    all_inference_results = predict_structure(
                            ^^^^^^^^^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 373, in predict_structure
    result = model_runner.run_inference(example, rng_key)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 311, in run_inference
    result = self._model(rng_key, featurised_example)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 84554026840 bytes.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
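One detail worth noting from the traceback: the single failed allocation is already larger than the 80 GB of HBM on one H100. Quick arithmetic (the 256 GB figure assumes `XLA_CLIENT_MEM_FRACTION=3.2` acts as a multiple of one device's memory when unified memory lets XLA spill into host RAM):

```python
# Size of the failed allocation reported in the traceback.
alloc_bytes = 84_554_026_840
print(f"{alloc_bytes / 1e9:.2f} GB")   # 84.55 GB -- already more than one H100's 80 GB
print(f"{alloc_bytes / 2**30:.2f} GiB")

# Headroom implied by XLA_CLIENT_MEM_FRACTION=3.2 on an 80 GB device,
# assuming it scales the per-device budget via unified memory:
print(f"{3.2 * 80:.0f} GB")
```

So if that 256 GB budget were in effect, this single ~85 GB request should fit, which makes me wonder whether the env vars actually reach the process inside the container.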

Any insights?
Cheers!
Celine
