
flash norm + blackwell softmax #4

Open
Lazarus-931 wants to merge 6 commits into patrick-toulme:main from Lazarus-931:main

@Lazarus-931

Contributor

Per the request in #2, this adds FlashNorm.

There is a per-kernel gain of 5-7%. I didn't test fully as @fm1320 had said, and I do think the 12% gains materialize in a 256-token decode with 33 norm calls. Per the paper, I made it two separate @_kernel calls, one launching xW and another computing RMS(x).
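For readers unfamiliar with the split, here is a minimal pure-Python sketch of why xW and RMS(x) can be computed independently and combined afterwards. It assumes the usual RMSNorm definition (eps inside the square root) and that the norm gain is folded into the weight matrix ahead of time, which is my reading of the FlashNorm paper. All function names are hypothetical and are not the kernels in this PR.

```python
import math

def rms(x, eps=1e-6):
    # Root mean square of the vector, with eps inside the sqrt (assumed convention).
    return math.sqrt(sum(v * v for v in x) / len(x) + eps)

def matvec(x, w):
    # x has length n, w is n rows by m columns; returns a length-m list.
    return [sum(x[i] * w[i][j] for i in range(len(x))) for j in range(len(w[0]))]

def rmsnorm_then_linear(x, g, w, eps=1e-6):
    # Baseline: normalize, apply gain g, then multiply by w.
    r = rms(x, eps)
    normed = [x[i] / r * g[i] for i in range(len(x))]
    return matvec(normed, w)

def fold_gain(g, w):
    # One-time preprocessing: scale row i of w by g[i].
    return [[g[i] * w_ij for w_ij in w[i]] for i in range(len(w))]

def flashnorm_linear(x, w_folded, eps=1e-6):
    xw = matvec(x, w_folded)  # stands in for the "xW" kernel
    r = rms(x, eps)           # stands in for the "RMS(x)" kernel
    return [v / r for v in xw]

x = [0.5, -1.25, 2.0, 0.75]
g = [1.0, 0.5, 2.0, 1.5]
w = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6], [0.7, -0.8]]

baseline = rmsnorm_then_linear(x, g, w)
fused = flashnorm_linear(x, fold_gain(g, w))
max_abs_diff = max(abs(a - b) for a, b in zip(baseline, fused))
```

The two results agree up to floating-point rounding, because dividing by the scalar RMS(x) commutes with the matrix product. The sketch says nothing about the speedup itself.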

Also included a softmax for Blackwell, ~2.7x faster than torch. But again, torch's std::exp is slower than the log2e approach used here, so it's not really an accurate comparison, as @ezyang mentioned.
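The log2e trick rests on the identity exp(x) = 2^(x * log2(e)), which lets a kernel use a base-2 exponential instead of a natural one. Below is a small pure-Python sketch of a softmax written both ways. It only illustrates the identity: the function names are hypothetical, and I am assuming, not confirming, that the kernel in this PR is structured this way.

```python
import math

LOG2E = math.log2(math.e)  # about 1.4427

def softmax_exp(row):
    # Reference: subtract the row max for stability, then use the natural exp.
    m = max(row)
    e = [math.exp(v - m) for v in row]
    s = sum(e)
    return [v / s for v in e]

def softmax_exp2(row):
    # Same computation, with exp(t) rewritten as 2 ** (t * log2(e)).
    m = max(row)
    e = [2.0 ** ((v - m) * LOG2E) for v in row]
    s = sum(e)
    return [v / s for v in e]

row = [1.0, 2.0, 3.0, -4.0]
ref = softmax_exp(row)
alt = softmax_exp2(row)
max_abs_diff = max(abs(a - b) for a, b in zip(ref, alt))
```

Both versions give the same probabilities up to rounding. A fair benchmark would therefore compare against a baseline that also uses the base-2 form.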

Comment thread pyptx/jax_support.py
kernel_name = _extract_entry_name(ptx_source)

err, module = driver.cuModuleLoadData(ptx_source.encode())
if err == driver.CUresult.CUDA_ERROR_UNSUPPORTED_PTX_VERSION:

Owner


Why are the changes to this file needed?

@Lazarus-931 May 3, 2026

Contributor Author


For the Blackwell softmax, I kept hitting compilation errors. Apparently torch's bundled CUDA driver JIT is built for sm_50–sm_90, so cuModuleLoadData always gave me a CUDA_ERROR_INVALID_PTX.

Errors like:

Found GPU0 NVIDIA B200 which is of compute capability (CC) 10.0.
The following list shows the CCs this version of PyTorch was built
for and the hardware CCs it supports:

  • 5.0 which supports hardware CC >=5.0,<6.0
  • 6.0 which supports hardware CC >=6.0,<7.0
  • 7.0 which supports hardware CC >=7.0,<8.0
  • 7.5 which supports hardware CC >=7.5,<8.0
  • 8.0 which supports hardware CC >=8.0,<9.0
  • 8.6 which supports hardware CC >=8.6,<9.0
  • 9.0 which supports hardware CC >=9.0,<10.0
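The list above can be read as a simple rule: a build target with compute capability (major, minor) covers hardware with the same major version and an equal or higher minor version. Here is a small pure-Python sketch of that rule, inferred from the message text only. The helper name and the hard-coded list are illustrative and are not a torch API.

```python
# Build targets copied from the message above, as (major, minor) tuples.
BUILT_FOR = [(5, 0), (6, 0), (7, 0), (7, 5), (8, 0), (8, 6), (9, 0)]

def is_covered(hardware_cc, built_for=BUILT_FOR):
    # Hypothetical helper: True if any build target shares the major version
    # and has a minor version no greater than the hardware's.
    major, minor = hardware_cc
    return any(b_major == major and b_minor <= minor
               for b_major, b_minor in built_for)

b200_covered = is_covered((10, 0))  # B200 reports CC 10.0
h100_covered = is_covered((9, 0))   # CC 9.0 is in the list
```

Under this reading, CC 10.0 matches none of the listed targets, which is consistent with the failure on the B200.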

@patrick-toulme left a comment

Owner


I'm confused about why jax_support.py needs changes.
