AlphaFold 3 on NVIDIA GB10 Blackwell GPU - Technical Report

Date: November 14-15, 2025 Hardware: NVIDIA DGX Spark with GB10 Blackwell GPU (Compute Capability sm_121a/sm_121) Goal: Get AlphaFold 3 running on Blackwell architecture


Executive Summary

Successfully configured AlphaFold 3 to run on NVIDIA's GB10 Blackwell GPU after resolving multiple compatibility issues across the Triton/JAX/CUDA stack. The main challenges were:

  1. Triton compiler lacking sm_121 support
  2. API incompatibilities between Triton versions and jax-triton
  3. LLVM backend version mismatches
  4. PTXAS assembler not recognizing sm_121a architecture suffix

Problem 1: Initial Triton Version Incompatibility

Issue

AlphaFold 3's pyproject.toml specified Triton 3.3.1, which uses LLVM 15. This LLVM version predates Blackwell GPU support.

Error Encountered

'sm_121a' is not a recognized processor for this target

Root Cause

Investigation Path

  1. Checked GitHub issue #394 which recommended CUDA 12.8+ and Triton 3.3.1
  2. Realized issue was for GH200 (Hopper), not GB10 (Blackwell)
  3. Found Triton PR #8498 (merged Oct 24, 2025) adding sm_120/sm_121 support to main branch
  4. Discovered Triton main uses LLVM from Sept 2025 with Blackwell support

Solution

Changed Dockerfile to build Triton from main branch instead of release/3.3.x:

# Line 77 in docker/Dockerfile git checkout main # Changed from: git checkout release/3.3.x

Key Learning: Hardware architecture support requires matching LLVM version, not just CUDA version.


Problem 2: The 'a' Suffix Mystery (sm_121a vs sm_121)

Issue

Even after upgrading to Triton main, still got architecture recognition errors.

Error Encountered

LLVM ERROR: Cannot select: intrinsic %llvm.nvvm.cp.async.commit.group

Root Cause

Triton was generating sm_121a but should generate sm_121. The 'a' suffix logic was:

suffix = "a" if capability >= 90 else ""

This applied the 'a' suffix to ALL compute capabilities ≥90, including sm_121 (Blackwell). However, the 'a' suffix is Hopper-specific (sm_90a) and should NOT be applied to Blackwell.

Investigation Path

  1. Found GitHub issue triton-lang/triton#8543 discussing the suffix problem
  2. Learned 'a' suffix indicates "architecture with TMA (Tensor Memory Accelerator)"
  3. Discovered Hopper (sm_90) needs 'a', but Blackwell (sm_121) does not

Solution

Patched Triton compiler to restrict 'a' suffix to sm_90 only:

# In triton/backends/nvidia/compiler.py suffix = "a" if capability == 90 else "" # Changed from: capability >= 90

Applied via sed in Dockerfile:

sed -i 's/suffix = "a" if capability >= 90 else ""/suffix = "a" if capability == 90 else ""/' \ third_party/nvidia/backend/compiler.py

Key Learning: Architecture suffixes are generation-specific, not monotonic across compute capabilities.


Problem 3: Triton API Changes (_builder vs _semantic)

Issue

AlphaFold 3's triton_utils.py used _builder parameter API, incompatible with Triton main.

Error Encountered

ValueError: Did you forget to add @triton.jit ? (_semantic argument must be provided outside of JIT functions.)

Root Cause

Investigation Path

  1. Found GitHub issue alphafold#486 discussing the revert
  2. Traced commit history: ec4254a (upgrade) → f2edd59 (revert) → latest (still reverted)
  3. Realized AlphaFold 3 stayed on _builder to maintain Triton 3.3.x compatibility

Solution

Modified AlphaFold 3's triton_utils.py to use _semantic API:

# src/alphafold3/jax/common/triton_utils.py def _dot_fn( a: tl.core.tensor, b: tl.core.tensor, *, trans_a: bool = False, trans_b: bool = False, _semantic, # Changed from _builder ): # ... all tl.static_assert and tl operations now use _semantic=_semantic

Key Learning: Major version upgrades often require API migrations; check for reverted commits in dependencies.


Problem 4: jax-triton Version Incompatibility

Issue

jax-triton 0.3.0 incompatible with Triton main's API changes.

Error Encountered

TypeError: CUDABackend.make_ttir() missing 1 required positional argument: 'capability'

Root Cause

Investigation Path

  1. Got error about missing capability argument
  2. Found GitHub issue jax-ml/jax-triton#365 discussing version mismatches
  3. Discovered commit 0c6c888 (one commit before main) works with Triton 3.5.x

Solution

Updated pyproject.toml and dev-requirements.txt to use specific working commit:

# pyproject.toml "jax-triton @ git+https://github.com/jax-ml/jax-triton.git@0c6c888" # dev-requirements.txt git+https://github.com/jax-ml/jax-triton.git@0c6c888

Key Learning: When upgrading major dependencies, intermediate library versions may be incompatible; pinning to specific commits can bridge compatibility gaps.


Problem 5: PTXAS Version Mismatch

Issue

Triton's bundled PTXAS assembler doesn't recognize sm_121a architecture.

Error Encountered

ptxas fatal: PTX with .target 'sm_121a' cannot be compiled for architecture 'sm_121'

Root Cause

Investigation Path

  1. Initially thought environment variable TRITON_PTXAS_PATH would solve it
  2. Realized the problem was dual: (a) Triton generating sm_121a AND (b) PTXAS not accepting it
  3. Found GitHub issue triton-lang/triton#8539 confirming PTXAS version issue
  4. PR #8543 merged to fix this permanently

Solution

Set environment variable to use newer PTXAS:

export TRITON_PTXAS_PATH=/alphafold3_venv/lib/python3.12/site-packages/nvidia/cuda_nvcc/bin/ptxas

Added to Dockerfile:

ENV TRITON_PTXAS_PATH="/alphafold3_venv/lib/python3.12/site-packages/nvidia/cuda_nvcc/bin/ptxas"

Key Learning: Compiler toolchains include multiple binaries (compiler + assembler); version mismatches can occur at any stage.


Final Working Configuration

Dockerfile Changes

  1. Triton: Build from main branch with sm_121 support
  2. Suffix Patch: Change capability >= 90 to capability == 90
  3. PTXAS Path: Set environment variable to use nvidia-cuda-nvcc PTXAS
# Line 77: Use Triton main instead of 3.3.x git checkout main # Line 79: Apply suffix patch sed -i 's/suffix = "a" if capability >= 90 else ""/suffix = "a" if capability == 90 else ""/' \ third_party/nvidia/backend/compiler.py # Line 108: Set PTXAS path ENV TRITON_PTXAS_PATH="/alphafold3_venv/lib/python3.12/site-packages/nvidia/cuda_nvcc/bin/ptxas"

pyproject.toml Changes

  1. Remove triton: Let Triton be built from source (line 28 removed)
  2. Update jax-triton: Use compatible commit
dependencies = [ "absl-py", "chex", "dm-haiku==0.0.14", "dm-tree", "jax==0.6.0", "jax[cuda12]==0.6.0", "jax-triton @ git+https://github.com/jax-ml/jax-triton.git@0c6c888", # Updated # triton==3.3.1 removed - built from source "jaxtyping==0.3.2", "numpy", "rdkit==2024.3.5", "tqdm", "typeguard==2.13.3", "zstandard", ]

AlphaFold 3 Source Changes

Modified src/alphafold3/jax/common/triton_utils.py:

Lines 68-82: Updated dot function and all tl.* operations


Technical Stack Summary

Working Versions

Architecture Details


Key Rabbit Holes & Dead Ends

Rabbit Hole 1: Following GH200 Instructions

What we tried: Following GitHub issue #394 instructions for Triton 3.3.1 Why it failed: Issue was for GH200 (Hopper/sm_90), not GB10 (Blackwell/sm_121) Lesson: Always verify hardware architecture matches the compatibility guide

Rabbit Hole 2: Assuming Suffix Logic Was Correct

What we tried: Initially focused on LLVM intrinsic errors Why it failed: Didn't realize 'a' suffix was being incorrectly applied Lesson: Architecture naming conventions have specific meanings; investigate suffix logic

Rabbit Hole 3: Trying to Use jax-triton main

What we tried: Upgraded to jax-triton from main branch Why it failed: Main branch uses native_specialize_impl that doesn't exist in Triton 3.5.x Lesson: Library main branches may target unreleased dependency versions

Rabbit Hole 4: Only Setting TRITON_PTXAS_PATH

What we tried: Set PTXAS environment variable without patching suffix logic Why it failed: Triton was still generating sm_121a which even newer PTXAS couldn't handle correctly Lesson: Environment variables alone can't fix code generation issues


Testing Results

Test Command

docker run --name af3_test_gb10 --rm \ --volume /home/ruh/research/PhD/related_work/alphafold3/af_input:/root/af_input \ --volume /home/ruh/research/PhD/related_work/alphafold3/af_output:/root/af_output \ --volume /home/ruh/data/af3_model:/root/models \ --volume /home/ruh/data/af_public_databases:/root/public_databases \ --volume /home/ruh/data/af3_jax_cache:/root/jax_cache \ --gpus all --memory=64g \ -e XLA_PYTHON_CLIENT_PREALLOCATE=false \ -e TF_FORCE_UNIFIED_MEMORY=true \ -e XLA_CLIENT_MEM_FRACTION=3.2 \ alphafold3 python run_alphafold_test.py

Expected Behavior


References

GitHub Issues

Key Commits


Recommendations for Future Work

For AlphaFold 3 Team

  1. Update official compatibility guide to distinguish Hopper (GH200) vs Blackwell (GB10)
  2. Consider maintaining separate branches for different Triton versions
  3. Add CI testing for Blackwell architecture

For Users with Blackwell GPUs

  1. Always build Triton from source (main branch) until official 3.6+ release
  2. Apply suffix patch as standard practice
  3. Use CUDA 12.9+ or 13.0+ for best compatibility
  4. Pin jax-triton to commit 0c6c888 until official release supports Triton 3.5+

For Docker Deployments

  1. Set TRITON_PTXAS_PATH environment variable in container
  2. Build Triton during image creation, not at runtime
  3. Cache JAX compilation artifacts to avoid recompilation

Conclusion

Getting AlphaFold 3 running on GB10 Blackwell required navigating a complex dependency chain:

CUDA 12.9Triton main (with suffix patch) → jax-triton 0c6c888AlphaFold 3 (with _semantic API)

The main challenges were:

  1. Architecture support: Required bleeding-edge Triton/LLVM for sm_121
  2. Suffix logic: Architectural detail that broke compilation
  3. API evolution: Triton 3.3 → 3.5 API changes
  4. Toolchain versions: PTXAS assembler version mismatch

Total time investment: ~2-3 hours of debugging across multiple GitHub issues and commit histories.

The experience highlights the challenges of running scientific software on cutting-edge hardware where the software ecosystem is still catching up to hardware releases.


Problem 6: CUDA Runtime Kernel Launch Error (ONGOING)

Issue

After fixing compilation issues, getting runtime CUDA error during kernel execution.

Error Encountered

CUDA_ERROR_INVALID_VALUE operation gpuLaunchKernel(...) failed Fatal Python error: Segmentation fault

Root Cause (Suspected)

Triton kernels may be using launch parameters (block dimensions, grid dimensions, or shared memory) that are incompatible with GB10 Blackwell architecture limits or have bugs specific to sm_121.

Investigation Path

  1. Fixed compilation (sm_121a → sm_121) ✓
  2. Kernel compiles and loads successfully ✓
  3. Kernel launch fails with INVALID_VALUE
  4. Likely issues:

Potential Solutions

  1. Check Triton version: Ensure using latest main with all Blackwell fixes
  2. Disable Triton kernels: Test if falling back to XLA kernels works
  3. Debug kernel parameters: Add logging to see actual launch parameters
  4. Report to Triton: This may be a bug in Triton's Blackwell kernel generation
  5. Try different Triton commit: Use a more recent commit from main branch

Status

⚠️ BLOCKED ON UPSTREAM - Compilation works, runtime blocked by incomplete JAX Blackwell support

Root Cause Identified

The kernel launch failure is due to incomplete Blackwell support in JAX/jaxlib, not Triton:

Attempted Solutions

  1. Triton suffix patch applied - sm_121asm_121 fixed
  2. PTXAS path configured - Using CUDA 12.9 PTXAS
  3. jax-triton compatibility tested - Tried 0c6c888, 6b9682a, and main
  4. All jax-triton versions fail:

Status: ⚠️ BLOCKED - Waiting for upstream JAX Blackwell support Date: November 15, 2025 Blocking Issues:

What Works:

What Doesn't Work:

Next Steps (When Unblocked):

  1. Wait for JAX team to complete Blackwell support in issue #31399
  2. Upgrade to JAX 0.9.0+ when released with full Blackwell support
  3. Rebuild Docker image with updated JAX/jaxlib
  4. Test AlphaFold 3 with full Blackwell stack

Estimated Timeline: