Skip to content

[ROCm] Enable BF16 top-p sampling kernel#53

Open
austin1997 wants to merge 1 commit into
ROCm:paddle_hackthonfrom
austin1997:rocm-bf16-top-p-sampling
Open

[ROCm] Enable BF16 top-p sampling kernel#53
austin1997 wants to merge 1 commit into
ROCm:paddle_hackthonfrom
austin1997:rocm-bf16-top-p-sampling

Conversation

@austin1997

Copy link
Copy Markdown

PR Category

Custom Device

PR Types

Bug fixes

Description

Enable BF16 support for the ROCm top_p_sampling GPU kernel.

This PR:

  • Maps phi::bfloat16 to hip_bfloat16 for HIP radix sort keys when HIP_VERSION >= 60100000.
  • Registers the top_p_sampling BF16 GPU kernel on ROCm when HIP BF16 is available.
  • Adds focused dygraph BF16 coverage for top_p_sampling.

Verification:

  • env TARGET=SKYLAKEX ninja -j 160 paddle_python
  • BF16 top_p_sampling repro on ROCm GPU
  • python3.12 -m unittest test_top_p_sampling.TestTopPAPIBF16
  • python3.12 test_top_p_sampling.py
  • prek run --files paddle/phi/kernels/gpu/top_p_sampling_kernel.cu test/legacy_test/test_top_p_sampling.py
  • git diff --check

是否引起精度变化

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant