Hi @Augustin-Zidek
It was lovely and very helpful to meet you guys on Thursday :)
As you suggested, and in order to get my 5746-token jobs to succeed, I threw all the additional flags at it (with the help of Claude, as this is all way beyond my skills and knowledge), and I still run into RESOURCE_EXHAUSTED.
```bash
#!/bin/bash
#SBATCH --job-name=AF3-4GPU
#SBATCH --mail-type=END,FAIL
#SBATCH --mail-user=xxx@crick.ac.uk
#SBATCH --partition=gh100
#SBATCH --reservation=h100
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=100
#SBATCH --mem=0
#SBATCH --time=72:00:00
#SBATCH --output=sbatch%j.log

ml purge
ml Singularity

# Create JAX cache directory
mkdir -p /xxx/jax_cache
chmod 777 /xxx/jax_cache

# Create temporary directory for model_config
TEMP_DIR=$(mktemp -d)
chmod 777 "$TEMP_DIR"

# Create model_config.py with documented sharding strategy
cat > "$TEMP_DIR/model_config.py" << 'EOL'
from typing import Sequence

from typing_extensions import TypeAlias

_Shape2DType: TypeAlias = tuple[int | None, int | None]

pair_transition_shard_spec: Sequence[_Shape2DType] = (
    (2048, None),
    (3072, 1024),
    (None, 512),
)
EOL

# Settings from documentation
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export TF_FORCE_UNIFIED_MEMORY=true
export XLA_CLIENT_MEM_FRACTION=3.2
export XLA_FLAGS="--xla_gpu_enable_triton_gemm=false"

singularity exec \
    --nv \
    --bind /xxx/:/root/af_input \
    --bind /xxx/:/root/af_output \
    --bind /flask/reference/Alphafold3_dataset/model_parameters:/root/models \
    --bind /flask/reference/Alphafold3_dataset/datasets:/root/public_databases \
    --bind /xxx/jax_cache:/root/jax_cache \
    --bind "$TEMP_DIR/model_config.py":/app/alphafold/alphafold3/model/model_config.py \
    /flask/apps/containers/Alphafold/3.0.0/alphafold3.sif \
    python /app/alphafold/run_alphafold.py \
    --json_path=/root/af_input/fold_input.json \
    --model_dir=/root/models \
    --db_dir=/root/public_databases \
    --output_dir=/root/af_output \
    --buckets 5746 \
    --jax_compilation_cache_dir=/root/jax_cache

# Cleanup temporary directory
rm -rf "$TEMP_DIR"
```
It runs for a few hours and then returns:
```
Running model inference for seed 1...
Traceback (most recent call last):
  File "/app/alphafold/run_alphafold.py", line 699, in <module>
    app.run(main)
  File "/alphafold3_venv/lib/python3.11/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/alphafold3_venv/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
             ^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 684, in main
    process_fold_input(
  File "/app/alphafold/run_alphafold.py", line 556, in process_fold_input
    all_inference_results = predict_structure(
                            ^^^^^^^^^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 373, in predict_structure
    result = model_runner.run_inference(example, rng_key)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/app/alphafold/run_alphafold.py", line 311, in run_inference
    result = self._model(rng_key, featurised_example)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 84554026840 bytes.
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
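In case it helps, here is my own back-of-the-envelope check (not from the AlphaFold docs, so I may be misreading it): converting the failing allocation into gigabytes suggests that this single buffer alone would not fit in the 80 GB of HBM on one H100, which makes me wonder whether the unified-memory / sharding settings are actually taking effect for this tensor.

```python
# Convert the failing allocation from the RESOURCE_EXHAUSTED message
# and compare it against the 80 GB of HBM on a single H100.
alloc_bytes = 84_554_026_840  # taken verbatim from the traceback

alloc_gb = alloc_bytes / 1e9     # decimal GB, as GPU memory sizes are quoted
alloc_gib = alloc_bytes / 2**30  # binary GiB, as XLA often reports

h100_hbm_gb = 80  # H100 80 GB variant

print(f"requested: {alloc_gb:.1f} GB ({alloc_gib:.1f} GiB); "
      f"per-GPU HBM: {h100_hbm_gb} GB")
# The ~84.6 GB request exceeds one GPU's 80 GB on its own.
```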
Any insights?
Cheers!
Celine