[Metax][Optimization] Reduce the vision-path overhead of PaddleOCR-VL on Metax GPUs #7619
base: develop
Changes from 1 commit
| Original file line number | Diff line number | Diff line change |
|---|---|---|
@@ -37,27 +37,33 @@ def rotate_half(x): | |
| def apply_rotary_pos_emb_vision(x, cos, sin): | ||
| -     orig_dtype = x.dtype | ||
| -     x = x.astype("float32") | ||
|       x_embed = (x * cos) + (rotate_half(x) * sin) | ||
|
❓ Question: Originally, the cast to float32 happened inside this function. Now the caller handles it with `if q_dtype != paddle.float32: qk = qkv[:, :2].astype("float32")`, so the current path is safe. However, the function itself no longer enforces that its input is float32. Consider stating the float32 requirement in the function docstring, or adding `assert x.dtype == paddle.float32, f"expected float32, got {x.dtype}"`. |
||
| -     return x_embed.astype(orig_dtype) | ||
| +     return x_embed | ||
| def native_neox_rope_embedding(qkv, cos, sin, num_heads): | ||
| -     B, seq_length, D = qkv.shape | ||
| -     if seq_length == -1: | ||
| -         _, seq_length, _ = paddle.shape(qkv) | ||
| -     qkv = qkv.reshape( | ||
| -         [ | ||
| -             seq_length, | ||
| -             3, | ||
| -             num_heads, | ||
| -             -1, | ||
| -         ] | ||
| -     ).transpose(perm=[1, 0, 2, 3]) | ||
| -     q, k, v = qkv.unbind(axis=0) | ||
| +     if qkv.dim() == 3: | ||
| +         B, seq_length, D = qkv.shape | ||
| +         if seq_length == -1: | ||
| +             _, seq_length, _ = paddle.shape(qkv) | ||
| +         token_count = B * seq_length | ||
| +     else: | ||
| +         token_count, D = qkv.shape | ||
| +         if token_count == -1: | ||
| +             token_count, _ = paddle.shape(qkv) | ||
| +     qkv = qkv.reshape([token_count, 3, num_heads, -1]) | ||
| +     q_dtype = qkv.dtype | ||
| +     if q_dtype != paddle.float32: | ||
| +         qk = qkv[:, :2].astype("float32") | ||
| +         q, k = qk[:, 0], qk[:, 1] | ||
| +     else: | ||
| +         q, k = qkv[:, 0], qkv[:, 1] | ||
| +     v = qkv[:, 2] | ||
|       q = apply_rotary_pos_emb_vision(q, cos, sin) | ||
|       k = apply_rotary_pos_emb_vision(k, cos, sin) | ||
| +     if q.dtype != q_dtype: | ||
| +         q = q.astype(q_dtype) | ||
| +         k = k.astype(q_dtype) | ||
|       return q, k, v | ||
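To make the dtype handling in the hunk above easier to follow, here is a minimal NumPy sketch of the same idea: flatten `qkv` to tokens, upcast `q` and `k` to float32 once in the caller, apply the rotation, then cast back. It is an illustration under assumptions, not the PR's Paddle code. NumPy stands in for Paddle, `rotate_half` is assumed to be the usual NeoX-style half rotation (its body is not shown in this diff), `cos`/`sin` are assumed to be float32 and broadcastable to `[tokens, heads, head_dim]`, and the names `apply_rope_fp32` and `rope_qkv` are hypothetical. The assertion follows the reviewer's suggestion.

```python
import numpy as np


def rotate_half(x):
    # Assumed NeoX-style rotation: split the last axis into halves (x1, x2)
    # and return (-x2, x1).
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([-x2, x1], axis=-1)


def apply_rope_fp32(x, cos, sin):
    # The helper no longer casts, so it checks its input instead
    # (the guard proposed in the review comment).
    assert x.dtype == np.float32, f"expected float32, got {x.dtype}"
    return x * cos + rotate_half(x) * sin


def rope_qkv(qkv, cos, sin, num_heads):
    # Accept [B, S, 3*H*D] or [T, 3*H*D] and flatten the leading axes to tokens.
    token_count = int(np.prod(qkv.shape[:-1]))
    qkv = qkv.reshape(token_count, 3, num_heads, -1)
    out_dtype = qkv.dtype
    # One cast covers q and k together; v is never upcast.
    qk = qkv[:, :2].astype(np.float32)
    q = apply_rope_fp32(qk[:, 0], cos, sin).astype(out_dtype)
    k = apply_rope_fp32(qk[:, 1], cos, sin).astype(out_dtype)
    return q, k, qkv[:, 2]
```

With this layout the helper stays free of casts, and the cost of the float32 round trip is paid once per call rather than once per tensor.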
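🟡 Suggestion
The new batch=1 fast path in SiglipEncoderLayer.forward unrolls exactly the same attention → residual → MLP → residual logic as the general path, so the module now carries two equivalent copies of the code. If the general path later needs changes (for example, adding LayerDrop or gradient checkpointing), both copies must be kept in sync, which raises the maintenance cost.
Consider extracting the shared logic into a private method.
The batch=1 fast path would then only do the squeeze/unsqueeze, and the general path would call the method directly.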
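The refactor suggested above could look roughly like the sketch below. It is a hypothetical illustration, not the reviewer's original snippet: `EncoderLayerSketch` and `_forward_core` are made-up names, `attn` and `mlp` are plain callables standing in for the real sub-layers (layer norms and attention arguments are omitted), and NumPy indexing stands in for Paddle's squeeze/unsqueeze.

```python
import numpy as np


class EncoderLayerSketch:
    """Hypothetical stand-in for SiglipEncoderLayer with callable sub-layers."""

    def __init__(self, attn, mlp):
        self.attn = attn
        self.mlp = mlp

    def _forward_core(self, hidden_states):
        # Single copy of the shared logic: attention -> residual -> MLP -> residual.
        hidden_states = hidden_states + self.attn(hidden_states)
        hidden_states = hidden_states + self.mlp(hidden_states)
        return hidden_states

    def forward(self, hidden_states):
        if hidden_states.ndim == 3 and hidden_states.shape[0] == 1:
            # batch=1 fast path: drop the batch axis, reuse the core, restore the axis.
            return self._forward_core(hidden_states[0])[None, ...]
        # General path: call the shared core directly.
        return self._forward_core(hidden_states)
```

Any later change to the block (for example, a new regularizer) would then touch only `_forward_core`, and both paths pick it up automatically.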