Date: November 14-15, 2025
Hardware: NVIDIA DGX Spark with GB10 Blackwell GPU (Compute Capability sm_121a/sm_121)
Goal: Get AlphaFold 3 running on Blackwell architecture
Successfully configured AlphaFold 3 to run on NVIDIA's GB10 Blackwell GPU after resolving multiple compatibility issues across the Triton/JAX/CUDA stack. The main challenges were:
- Triton 3.3.1's bundled LLVM predating Blackwell support
- Triton applying the Hopper-only 'a' suffix to sm_121
- AlphaFold 3's Triton utilities using the old _builder API
- jax-triton 0.3.0 lagging behind Triton main's API
- Triton's bundled PTXAS being too old for sm_121
AlphaFold 3's pyproject.toml specified Triton 3.3.1, which uses LLVM 15. This LLVM version predates Blackwell GPU support.
'sm_121a' is not a recognized processor for this target
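Since the error names sm_121a, it is worth confirming what the device itself reports before blaming the compiler. A minimal check, assuming JAX's CUDA backend initializes on this machine (GB10 should report compute capability 12.1, i.e. sm_121):

```python
import jax

# Sketch: ask JAX what the GPU reports. On GB10 this is expected to print
# something like "NVIDIA GB10 12.1"; assumes the CUDA backend initializes.
gpu = jax.devices("gpu")[0]
print(gpu.device_kind, gpu.compute_capability)
```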
Changed Dockerfile to build Triton from main branch instead of release/3.3.x:
```dockerfile
# Line 77 in docker/Dockerfile
git checkout main  # Changed from: git checkout release/3.3.x
```
Key Learning: Hardware architecture support requires matching LLVM version, not just CUDA version.
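One way to see this directly is to ask LLVM which NVPTX processors it knows about. An illustrative check only, assuming an llc binary is on PATH (Triton vendors its own LLVM, so treat this as a proxy): an LLVM 15-era llc will not list sm_121.

```python
import subprocess

# Illustrative: list the NVPTX processors this LLVM build recognizes.
proc = subprocess.run(
    ["llc", "-march=nvptx64", "-mcpu=help"],
    capture_output=True,
    text=True,
)
listing = proc.stdout + proc.stderr  # the CPU list may go to either stream
print("sm_121 supported:", "sm_121" in listing)
```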
Even after upgrading to Triton main, still got architecture recognition errors.
LLVM ERROR: Cannot select: intrinsic %llvm.nvvm.cp.async.commit.group
Triton was generating sm_121a but should generate sm_121. The 'a' suffix logic was:
```python
suffix = "a" if capability >= 90 else ""
```
This applied the 'a' suffix to ALL compute capabilities ≥90, including sm_121 (Blackwell). However, the 'a' suffix is Hopper-specific (sm_90a) and should NOT be applied to Blackwell.
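A minimal sketch of the before/after suffix behavior (a standalone illustration, not Triton's actual module layout):

```python
def arch_suffix_before(capability: int) -> str:
    # Original logic: every capability >= 90 gets the "a" suffix.
    return "a" if capability >= 90 else ""

def arch_suffix_after(capability: int) -> str:
    # Patched logic: only sm_90 (Hopper) keeps the "a" suffix.
    return "a" if capability == 90 else ""

for cap in (80, 90, 121):
    print(f"sm_{cap}{arch_suffix_before(cap)} -> sm_{cap}{arch_suffix_after(cap)}")
# sm_80 -> sm_80
# sm_90a -> sm_90a
# sm_121a -> sm_121
```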
Patched Triton compiler to restrict 'a' suffix to sm_90 only:
```python
# In triton/backends/nvidia/compiler.py
suffix = "a" if capability == 90 else ""  # Changed from: capability >= 90
```
Applied via sed in Dockerfile:
```bash
sed -i 's/suffix = "a" if capability >= 90 else ""/suffix = "a" if capability == 90 else ""/' \
    third_party/nvidia/backend/compiler.py
```
Key Learning: Architecture suffixes are generation-specific, not monotonic across compute capabilities.
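A quick sanity check, assuming the patched Triton is importable inside the image, that the sed edit above actually landed in the installed package:

```python
import pathlib

import triton

# Sketch: the file path mirrors the compiler.py location referenced earlier.
compiler_py = pathlib.Path(triton.__file__).parent / "backends" / "nvidia" / "compiler.py"
assert 'capability == 90' in compiler_py.read_text(), "suffix patch not applied"
```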
AlphaFold 3's triton_utils.py used _builder parameter API, incompatible with Triton main.
ValueError: Did you forget to add @triton.jit ? (_semantic argument must be provided outside of JIT functions.)
Triton main expects the _semantic API (Triton 3.5+), while _builder is only needed for Triton 3.3.x compatibility. The choice was between switching to the _semantic parameter instead of _builder, or keeping _builder to maintain Triton 3.3.x compatibility. Since we now build Triton from main, we modified AlphaFold 3's triton_utils.py to use the _semantic API:
```python
# src/alphafold3/jax/common/triton_utils.py
def _dot_fn(
    a: tl.core.tensor,
    b: tl.core.tensor,
    *,
    trans_a: bool = False,
    trans_b: bool = False,
    _semantic,  # Changed from _builder
):
    # ... all tl.static_assert and tl operations now use _semantic=_semantic
```
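Had dual compatibility been a requirement, a hypothetical shim (not part of AlphaFold 3) could forward the context object under whichever keyword the installed builtin accepts:

```python
import inspect

def forward_ctx(fn, ctx, *args, **kwargs):
    # Hypothetical compatibility shim: pass the non-JIT context object as
    # _semantic (Triton main / 3.5+) or _builder (Triton 3.3.x), depending on
    # which keyword the callee's signature declares.
    kw = "_semantic" if "_semantic" in inspect.signature(fn).parameters else "_builder"
    return fn(*args, **{**kwargs, kw: ctx})
```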
Key Learning: Major version upgrades often require API migrations; check for reverted commits in dependencies.
jax-triton 0.3.0 incompatible with Triton main's API changes.
TypeError: CUDABackend.make_ttir() missing 1 required positional argument: 'capability'
jax-triton 0.3.0 calls CUDABackend.make_ttir() without the capability argument, while Triton main's make_ttir() requires a capability parameter; newer jax-triton commits do pass the capability argument. Updated pyproject.toml and dev-requirements.txt to use a specific working commit:
```
# pyproject.toml
"jax-triton @ git+https://github.com/jax-ml/jax-triton.git@0c6c888"

# dev-requirements.txt
git+https://github.com/jax-ml/jax-triton.git@0c6c888
```
Key Learning: When upgrading major dependencies, intermediate library versions may be incompatible; pinning to specific commits can bridge compatibility gaps.
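To confirm which side of this API break an installed Triton sits on, a small check (the import path is an assumption based on the compiler.py location referenced above):

```python
import inspect

# Sketch: does the installed Triton's make_ttir take a `capability` argument?
# That is the mismatch behind the TypeError above.
from triton.backends.nvidia.compiler import CUDABackend

print("capability" in inspect.signature(CUDABackend.make_ttir).parameters)
```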
Triton's bundled PTXAS assembler doesn't recognize sm_121a architecture.
ptxas fatal: PTX with .target 'sm_121a' cannot be compiled for architecture 'sm_121'
Triton invokes its bundled assembler at triton/backends/nvidia/bin/ptxas, but pointing TRITON_PTXAS_PATH at a newer ptxas would solve it. Set the environment variable to use the newer PTXAS installed in the virtual environment:
```bash
export TRITON_PTXAS_PATH=/alphafold3_venv/lib/python3.12/site-packages/nvidia/cuda_nvcc/bin/ptxas
```
Added to Dockerfile:
```dockerfile
ENV TRITON_PTXAS_PATH="/alphafold3_venv/lib/python3.12/site-packages/nvidia/cuda_nvcc/bin/ptxas"
```
Key Learning: Compiler toolchains include multiple binaries (compiler + assembler); version mismatches can occur at any stage.
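A quick way to verify which assembler Triton will now invoke and its CUDA version (sm_121 needs a CUDA 12.9-era ptxas); the default path below assumes the TRITON_PTXAS_PATH value from the Dockerfile:

```python
import os
import subprocess

# Sketch: print the version banner of the PTXAS that Triton will actually use.
ptxas = os.environ.get(
    "TRITON_PTXAS_PATH",
    "/alphafold3_venv/lib/python3.12/site-packages/nvidia/cuda_nvcc/bin/ptxas",
)
print(subprocess.run([ptxas, "--version"], capture_output=True, text=True).stdout)
```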
Changed the suffix logic from capability >= 90 to capability == 90. Summary of Dockerfile changes:

```dockerfile
# Line 77: Use Triton main instead of 3.3.x
git checkout main

# Line 79: Apply suffix patch
sed -i 's/suffix = "a" if capability >= 90 else ""/suffix = "a" if capability == 90 else ""/' \
    third_party/nvidia/backend/compiler.py

# Line 108: Set PTXAS path
ENV TRITON_PTXAS_PATH="/alphafold3_venv/lib/python3.12/site-packages/nvidia/cuda_nvcc/bin/ptxas"
```
pyproject.toml:

```toml
dependencies = [
    "absl-py",
    "chex",
    "dm-haiku==0.0.14",
    "dm-tree",
    "jax==0.6.0",
    "jax[cuda12]==0.6.0",
    "jax-triton @ git+https://github.com/jax-ml/jax-triton.git@0c6c888",  # Updated
    # triton==3.3.1 removed - built from source
    "jaxtyping==0.3.2",
    "numpy",
    "rdkit==2024.3.5",
    "tqdm",
    "typeguard==2.13.3",
    "zstandard",
]
```
Modified src/alphafold3/jax/common/triton_utils.py:
- Renamed _builder parameters to _semantic
- Changed keyword arguments to _semantic=_semantic
- Lines 68-82: updated the dot function and all tl.* operations
What we tried: Following GitHub issue #394's instructions for Triton 3.3.1
Why it failed: That issue was for GH200 (Hopper/sm_90), not GB10 (Blackwell/sm_121)
Lesson: Always verify the hardware architecture matches the compatibility guide
What we tried: Initially focused on the LLVM intrinsic errors
Why it failed: Didn't realize the 'a' suffix was being incorrectly applied
Lesson: Architecture naming conventions have specific meanings; investigate suffix logic
What we tried: Upgraded to jax-triton from main branch
Why it failed: Main branch uses native_specialize_impl that doesn't exist in Triton 3.5.x
Lesson: Library main branches may target unreleased dependency versions
What we tried: Set the PTXAS environment variable without patching the suffix logic
Why it failed: Triton was still generating sm_121a, which even the newer PTXAS couldn't handle correctly
Lesson: Environment variables alone can't fix code generation issues
```bash
docker run --name af3_test_gb10 --rm \
  --volume /home/ruh/research/PhD/related_work/alphafold3/af_input:/root/af_input \
  --volume /home/ruh/research/PhD/related_work/alphafold3/af_output:/root/af_output \
  --volume /home/ruh/data/af3_model:/root/models \
  --volume /home/ruh/data/af_public_databases:/root/public_databases \
  --volume /home/ruh/data/af3_jax_cache:/root/jax_cache \
  --gpus all --memory=64g \
  -e XLA_PYTHON_CLIENT_PREALLOCATE=false \
  -e TF_FORCE_UNIFIED_MEMORY=true \
  -e XLA_CLIENT_MEM_FRACTION=3.2 \
  alphafold3 python run_alphafold_test.py
```
Also verified the TRITON_PTXAS_PATH environment variable is set in the container.

Getting AlphaFold 3 running on GB10 Blackwell required navigating a complex dependency chain:
CUDA 12.9 → Triton main (with suffix patch) → jax-triton 0c6c888 → AlphaFold 3 (with _semantic API)
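For reproducibility, it is worth recording the stack that was actually resolved; a small sketch using standard importlib.metadata (distribution names as used above, jaxlib added for completeness):

```python
import importlib.metadata as md

# Sketch: print the installed versions of the dependency chain.
for dist in ("jax", "jaxlib", "triton", "jax-triton"):
    try:
        print(dist, md.version(dist))
    except md.PackageNotFoundError:
        print(dist, "not installed")
```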
The main challenges were the LLVM/Blackwell mismatch in Triton 3.3.1, the over-broad 'a' suffix, the _builder → _semantic API migration, pinning a compatible jax-triton commit, and the outdated bundled PTXAS.
Total time investment: ~2-3 hours of debugging across multiple GitHub issues and commit histories.
The experience highlights the challenges of running scientific software on cutting-edge hardware where the software ecosystem is still catching up to hardware releases.
After fixing the compilation issues, we hit a runtime CUDA error during kernel execution.
```
CUDA_ERROR_INVALID_VALUE
operation gpuLaunchKernel(...) failed
Fatal Python error: Segmentation fault
```
Triton kernels may be using launch parameters (block dimensions, grid dimensions, or shared memory) that are incompatible with GB10 Blackwell architecture limits or have bugs specific to sm_121.
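When this is revisited, standard JAX debugging switches can provide more context around the failing launch; a sketch (the flag names are standard JAX config options, and any XLA dump path would be illustrative):

```python
import jax

# Sketch: make the failing compile/launch easier to inspect.
jax.config.update("jax_log_compiles", True)          # log every XLA compilation
jax.config.update("jax_traceback_filtering", "off")  # keep full tracebacks
# Additionally, XLA_FLAGS=--xla_dump_to=/some/dir in the environment dumps the
# compiled HLO for the module that triggers the kernel launch failure.
```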
⚠️ BLOCKED ON UPSTREAM - Compilation works, runtime blocked by incomplete JAX Blackwell support
The kernel launch failure is due to incomplete Blackwell support in JAX/jaxlib, not Triton:
- sm_121a → sm_121 code generation: fixed
- CUDA_ERROR_INVALID_VALUE kernel launch failure
- native_specialize_impl missing (requires a newer Triton API)

Status: ⚠️ BLOCKED - Waiting for upstream JAX Blackwell support
Date: November 15, 2025
Blocking Issues:
- Incomplete Blackwell (sm_121) support in JAX/jaxlib, which causes the CUDA_ERROR_INVALID_VALUE kernel launch failure
What Works:
- Compilation: Triton kernels now compile for sm_121 (suffix patch + newer PTXAS)
What Doesn't Work:
- Runtime kernel launch: fails with CUDA_ERROR_INVALID_VALUE

Next Steps (When Unblocked):
Estimated Timeline: