56 changes: 34 additions & 22 deletions fastdeploy/model_executor/models/paddleocr_vl/projector.py
@@ -17,6 +17,7 @@
import math
from typing import Optional

import numpy as np
import paddle
import paddle.nn as nn

@@ -63,30 +64,41 @@ def __init__(self, text_config, vision_config, prefix=""):
self.linear_2 = nn.Linear(self.hidden_size, self.text_config.hidden_size)
self.linear_2.weight.weight_loader = self.weight_loader

def forward(self, image_features, image_grid_thw):
def _build_merge_permutation(self, image_grid_thw):
m1, m2 = self.merge_kernel_size
if isinstance(image_grid_thw, paddle.Tensor):
image_grid_thw = image_grid_thw.cpu().numpy()

merge_indices = []
merge_lengths = []
start = 0
for image_grid in image_grid_thw:
t, h, w = map(int, image_grid)
assert h % m1 == 0 and w % m2 == 0, (image_grid, self.merge_kernel_size)
local = np.arange(t * h * w, dtype=np.int64).reshape((t, h // m1, m1, w // m2, m2))
local = local.transpose((0, 1, 3, 2, 4)).reshape(-1)
merge_indices.append(local + start)
merge_lengths.append(t * (h // m1) * (w // m2))
start += t * h * w

if len(merge_indices) == 0:
return np.empty((0,), dtype=np.int64), merge_lengths
return np.concatenate(merge_indices, axis=0), merge_lengths

def forward(self, image_features, image_grid_thw, return_packed: bool = False):
if isinstance(image_features, (list, tuple)):
processed_features = list()
for image_feature, image_grid in zip(image_features, image_grid_thw):
image_feature = self.pre_norm(image_feature) # shape: (T*H*W, D)
t, h, w = image_grid
from einops import rearrange

image_feature = rearrange(
image_feature,
"(t h p1 w p2) d -> (t h w) (p1 p2 d)",
t=int(t),
h=int(h // m1),
p1=int(m1),
w=int(w // m2),
p2=int(m2),
)
hidden_states = self.linear_1(image_feature)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
processed_features.append(hidden_states)

return processed_features
packed_image_features = image_features[0] if len(image_features) == 1 else paddle.concat(image_features, axis=0)
packed_image_features = self.pre_norm(packed_image_features)
merge_indices, merge_lengths = self._build_merge_permutation(image_grid_thw)
merge_indices = paddle.to_tensor(merge_indices, dtype="int64", place=packed_image_features.place)
packed_image_features = paddle.index_select(packed_image_features, merge_indices, axis=0)
hidden_states = paddle.reshape(packed_image_features, [-1, self.hidden_size])
hidden_states = self.linear_1(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.linear_2(hidden_states)
if return_packed:
return hidden_states
return list(paddle.split(hidden_states, merge_lengths, axis=0))

dim = image_features.shape[-1]
image_features = paddle.reshape(image_features, [-1, dim])
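For reference, a small NumPy sketch (illustrative sizes only, not part of the diff) of the permutation that _build_merge_permutation produces for a single grid, mirroring the reshape/transpose above:

import numpy as np

# One image: t=1, h=4, w=4, merge kernel (2, 2) -> 4 merged tokens.
t, h, w, m1, m2 = 1, 4, 4, 2, 2
local = np.arange(t * h * w, dtype=np.int64).reshape((t, h // m1, m1, w // m2, m2))
perm = local.transpose((0, 1, 3, 2, 4)).reshape(-1)
print(perm)
# [ 0  1  4  5  2  3  6  7  8  9 12 13 10 11 14 15]
# Each run of m1*m2 consecutive indices is one 2x2 spatial block, so gathering
# rows with this permutation and reshaping to (-1, m1 * m2 * D) packs each
# block into one merged token -- the same layout the removed einops pattern
# "(t h p1 w p2) d -> (t h w) (p1 p2 d)" produced per image.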
62 changes: 46 additions & 16 deletions fastdeploy/model_executor/models/paddleocr_vl/siglip.py
@@ -127,7 +127,11 @@ def forward(
cos_emb: Optional[paddle.Tensor] = None, # (cos, sin)
sin_emb: Optional[paddle.Tensor] = None, # (cos, sin)
):
B, seq_length, D = hidden_states.shape
if hidden_states.dim() == 3:
assert hidden_states.shape[0] == 1, f"SiglipAttention only supports batch=1, got {hidden_states.shape}"
hidden_states = hidden_states[0]

seq_length, D = hidden_states.shape
qkv = self.qkv_proj(hidden_states)
q, k, v = neox_rope_embedding(qkv, cos_emb, sin_emb, self.num_heads, self.head_dim)
attn_output = self.flash_attn_func(
@@ -255,25 +259,26 @@ def forward(
flatten_image_grid_thw = self.flatten_list(image_grid_thw)
flatten_image_grid_thw = np.array(flatten_image_grid_thw)
assert batch_size == 1
start = 0

assert sum([np.prod(x) for x in flatten_image_grid_thw]) == embeddings.shape[1], (
flatten_image_grid_thw,
embeddings.shape,
)
embeddings = embeddings.squeeze(0)
tmp_embeddings = list()
for image_grid in image_grid_thw:
t, h, w = image_grid
end = start + t * h * w
image_embeddings = embeddings[int(start) : int(end), :]
position_embedding = (
self.interpolate_pos_encoding(image_embeddings, h, w, True).squeeze(0).tile((t, 1))
).astype(image_embeddings.dtype)
image_embeddings = image_embeddings + position_embedding
tmp_embeddings.append(image_embeddings)
start = end
embeddings = paddle.concat(tmp_embeddings, axis=0).unsqueeze(0)
packed_position_embeddings = []
for t, h, w in flatten_image_grid_thw:
t, h, w = map(int, (t, h, w))
position_embedding = self.fetch_position_embedding_lfu_cache(embeddings, h, w).squeeze(0)
if t > 1:
position_embedding = position_embedding.tile((t, 1))
if position_embedding.dtype != embeddings.dtype:
position_embedding = position_embedding.astype(embeddings.dtype)
packed_position_embeddings.append(position_embedding)
if len(packed_position_embeddings) == 1:
packed_position_embeddings = packed_position_embeddings[0]
else:
packed_position_embeddings = paddle.concat(packed_position_embeddings, axis=0)
embeddings = (embeddings + packed_position_embeddings).unsqueeze(0)
else:
embeddings = embeddings + self.packing_position_embedding(position_ids)
return embeddings
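fetch_position_embedding_lfu_cache is referenced above but not shown in this diff; one plausible shape for such a cache, keyed on the (h, w) grid (hypothetical sketch with LRU-style eviction for brevity, not the PR's actual implementation):

from collections import OrderedDict

class PositionEmbeddingCache:
    """Caches interpolated position embeddings per (h, w) so repeated grid
    sizes skip re-interpolation; evicts the least recently used entry."""

    def __init__(self, max_entries: int = 64):
        self._cache = OrderedDict()
        self._max_entries = max_entries

    def fetch(self, module, embeddings, h: int, w: int):
        key = (h, w)
        if key in self._cache:
            self._cache.move_to_end(key)  # mark as recently used
            return self._cache[key]
        value = module.interpolate_pos_encoding(embeddings, h, w, True)
        self._cache[key] = value
        if len(self._cache) > self._max_entries:
            self._cache.popitem(last=False)  # drop least recently used
        return value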
@@ -307,7 +312,7 @@ def weight_loader(self, param, loaded_weight, loaded_shard_id: Optional[str] = N

def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor:
hidden_states = self.fc1(hidden_states)
hidden_states = get_activation_fn(self.config.hidden_act)(hidden_states[0])
hidden_states = get_activation_fn(self.config.hidden_act)(hidden_states)
hidden_states = self.fc2(hidden_states)
return hidden_states

@@ -331,6 +336,32 @@ def forward(
cos_emb=None,
sin_emb=None,
):
if hidden_states.dim() == 3 and hidden_states.shape[0] == 1:

🟡 Suggestion: the new batch=1 fast path in SiglipEncoderLayer.forward duplicates the generic path's attention → residual → MLP → residual logic exactly, leaving two equivalent copies of the same code in one module.

If the generic path later needs changes (e.g., adding LayerDrop or gradient checkpointing), both copies must be updated in lockstep, which raises the maintenance cost.

Consider extracting the shared logic into a private method, for example:

def _forward_single(self, hidden_states, ...):
    h = hidden_states + self.self_attn(self.layer_norm1(hidden_states), ...)
    return h + self.mlp(self.layer_norm2(h))

The batch=1 fast path then only handles the squeeze/unsqueeze, and the generic path can call the helper directly.
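
A fuller version of that helper (hypothetical refactor; the argument names and the layer_norm1/layer_norm2/self_attn/mlp attributes are taken from the surrounding diff):

def _forward_single(
    self,
    hidden_states,  # (seq_len, D), already squeezed
    attention_mask=None,
    output_attentions=False,
    cu_seqlens=None,
    max_seqlen=None,
    cos_emb=None,
    sin_emb=None,
):
    # Pre-norm attention block with residual connection.
    attn_out = self.self_attn(
        hidden_states=self.layer_norm1(hidden_states),
        attention_mask=attention_mask,
        output_attentions=output_attentions,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        cos_emb=cos_emb,
        sin_emb=sin_emb,
    )
    hidden_states = hidden_states + attn_out
    # Pre-norm MLP block with residual connection.
    return hidden_states + self.mlp(self.layer_norm2(hidden_states))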

hidden_states = hidden_states[0]

residual = hidden_states
ln1_out = self.layer_norm1(hidden_states)

x = self.self_attn(
hidden_states=ln1_out,
attention_mask=attention_mask,
output_attentions=output_attentions,
cu_seqlens=cu_seqlens,
max_seqlen=max_seqlen,
cos_emb=cos_emb,
sin_emb=sin_emb,
)

hs_post_attn = residual + x

residual = hs_post_attn
ln2_out = self.layer_norm2(residual)

mlp_out = self.mlp(ln2_out)

hidden_states_out = residual + mlp_out

return (hidden_states_out.unsqueeze(0),)

residual = hidden_states
############################
@@ -677,7 +708,6 @@ def forward(
end = cu_seqlens[i + 1]
tensor = last_hidden_state[:, start:end, :].squeeze(0)
sample_hidden_state.append(tensor)

return sample_hidden_state


36 changes: 21 additions & 15 deletions fastdeploy/model_executor/models/paddleocr_vl/siglip_ops.py
@@ -37,27 +37,33 @@ def rotate_half(x):


def apply_rotary_pos_emb_vision(x, cos, sin):

orig_dtype = x.dtype
x = x.astype("float32")
x_embed = (x * cos) + (rotate_half(x) * sin)

❓ Question: apply_rotary_pos_emb_vision used to carry an internal x.astype("float32") precision guard; that guard is now removed entirely, so callers must convert before passing x in.

native_neox_rope_embedding already does this at the call site:

if q_dtype != paddle.float32:
    qk = qkv[:, :2].astype("float32")

so the current path is safe. But apply_rotary_pos_emb_vision is a module-level public function; if another call site ever passes bfloat16 input directly, the precision guard is silently lost.

Consider noting in the docstring that x must be float32, or adding a lightweight assertion at the top of the function:

assert x.dtype == paddle.float32, f"expected float32, got {x.dtype}"
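
Applied to the function above, the guarded variant is a small change (sketch only, keeping the new float32-in/float32-out contract):

def apply_rotary_pos_emb_vision(x, cos, sin):
    """NeoX-style rotary embedding for vision tokens; `x` must already be float32."""
    assert x.dtype == paddle.float32, f"expected float32, got {x.dtype}"
    return (x * cos) + (rotate_half(x) * sin)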

return x_embed.astype(orig_dtype)
return x_embed


def native_neox_rope_embedding(qkv, cos, sin, num_heads):
B, seq_length, D = qkv.shape
if seq_length == -1:
_, seq_length, _ = paddle.shape(qkv)
qkv = qkv.reshape(
[
seq_length,
3,
num_heads,
-1,
]
).transpose(perm=[1, 0, 2, 3])
q, k, v = qkv.unbind(axis=0)
if qkv.dim() == 3:
B, seq_length, D = qkv.shape
if seq_length == -1:
_, seq_length, _ = paddle.shape(qkv)
token_count = B * seq_length
else:
token_count, D = qkv.shape
if token_count == -1:
token_count, _ = paddle.shape(qkv)
qkv = qkv.reshape([token_count, 3, num_heads, -1])
q_dtype = qkv.dtype
if q_dtype != paddle.float32:
qk = qkv[:, :2].astype("float32")
q, k = qk[:, 0], qk[:, 1]
else:
q, k = qkv[:, 0], qkv[:, 1]
v = qkv[:, 2]
q = apply_rotary_pos_emb_vision(q, cos, sin)
k = apply_rotary_pos_emb_vision(k, cos, sin)
if q.dtype != q_dtype:
q = q.astype(q_dtype)
k = k.astype(q_dtype)
return q, k, v
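
A quick shape walk-through of the new packed path (hypothetical sizes; assumes rotate_half and native_neox_rope_embedding from this module, and cos/sin shaped to broadcast over the head axis):

import paddle

num_heads, head_dim, tokens = 4, 8, 16
qkv = paddle.randn([tokens, 3 * num_heads * head_dim]).astype("bfloat16")
# Identity rotation (cos=1, sin=0) keeps the check simple.
cos = paddle.ones([tokens, 1, head_dim], dtype="float32")
sin = paddle.zeros([tokens, 1, head_dim], dtype="float32")
q, k, v = native_neox_rope_embedding(qkv, cos, sin, num_heads)
# q/k are rotated in float32 and cast back to the input dtype; v is untouched.
assert q.shape == [tokens, num_heads, head_dim]
assert q.dtype == qkv.dtype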

