Fix vLLM vision GRPO compatibility for issue #4081 #520
base: main
Changes from 4 commits
@@ -372,6 +372,79 @@ def patch_vllm_lora_load_tensors():

```python
    pass
pass


def patch_vllm_multimodal_seq2text():
    # vLLM multimodal prompt updates can pass token payloads as nested
    # dict/list/tensor objects. Normalize these payloads before decode.
    try:
        import vllm.multimodal.processing.processor as mm_processor
        original_seq2text = mm_processor._seq2text
        if getattr(original_seq2text, "__unsloth_patched_seq2text__", False):
            return

        def _extract_token_ids(payload, depth = 0):
            if depth >= 8:
                return None
            if isinstance(payload, str):
                return payload
```

Contributor comment on lines +397 to +398:
> `_extract_token_ids` can also return a string here. To make the function's responsibility clearer (to only extract token IDs), it would be better for it to only return a list of integers or `None`.
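The reviewer's suggestion can be sketched as a stricter variant whose contract is "list of ints or `None`", never a string. This is an illustrative sketch, not the PR's code; the function name and the simplified type handling (no torch/numpy) are assumptions:

```python
# Illustrative strict extractor: returns a flat list of ints, or None.
# Strings are treated as "not token ids" rather than passed through.
def extract_strict(payload, depth=0):
    if depth >= 8 or isinstance(payload, str):
        return None
    if isinstance(payload, int):
        return [payload]
    if isinstance(payload, (list, tuple)):
        merged = []
        for item in payload:
            found = extract_strict(item, depth + 1)
            if found is None:
                return None  # any unflattenable element poisons the result
            merged.extend(found)
        return merged
    return None

print(extract_strict([7, [8, 9]]))   # [7, 8, 9]
print(extract_strict("raw text"))    # None
```

With a single return type, the caller's `if result is None` check is the only branch needed.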
```python
            if torch.is_tensor(payload):
                payload = payload.tolist()
            elif isinstance(payload, np.ndarray):
                payload = payload.tolist()
            if isinstance(payload, (int, np.integer)):
                return [int(payload)]
            if isinstance(payload, tuple):
                payload = list(payload)
            if isinstance(payload, list):
                if len(payload) == 0:
                    return payload
                if all(isinstance(x, (int, np.integer)) for x in payload):
                    return [int(x) for x in payload]
                if len(payload) == 1:
                    return _extract_token_ids(payload[0], depth = depth + 1)
                merged = []
                for item in payload:
                    found = _extract_token_ids(item, depth = depth + 1)
                    if not isinstance(found, list):
                        return None
                    merged.extend(found)
                return merged
            if isinstance(payload, dict):
                preferred_keys = (
                    "token_ids",
                    "input_ids",
                    "ids",
                    "prompt_token_ids",
                    "token_id",
                    "tokens",
                    "prompt",
                    "content",
                )
                for key in preferred_keys:
                    if key in payload:
                        found = _extract_token_ids(payload[key], depth = depth + 1)
                        if found is not None:
                            return found
                for value in payload.values():
                    found = _extract_token_ids(value, depth = depth + 1)
                    if found is not None:
                        return found
            return None
        pass
```
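The normalization logic above can be exercised standalone. The following is a simplified, pure-Python mirror of the same idea (the PR's version additionally unwraps torch tensors and numpy arrays, and checks a longer key-preference list); function and key names here are illustrative:

```python
# Simplified sketch of the PR's payload normalizer: walk nested
# dict/list/tuple payloads and pull out a flat list of token ids.
def extract_token_ids(payload, depth=0):
    if depth >= 8:              # guard against pathological nesting
        return None
    if isinstance(payload, str):
        return payload          # already decodable text: pass through
    if isinstance(payload, int):
        return [payload]
    if isinstance(payload, tuple):
        payload = list(payload)
    if isinstance(payload, list):
        if not payload:
            return payload
        if all(isinstance(x, int) for x in payload):
            return payload
        merged = []
        for item in payload:
            found = extract_token_ids(item, depth + 1)
            if not isinstance(found, list):
                return None     # mixed payload we cannot flatten safely
            merged.extend(found)
        return merged
    if isinstance(payload, dict):
        for key in ("token_ids", "input_ids", "tokens"):
            if key in payload:
                found = extract_token_ids(payload[key], depth + 1)
                if found is not None:
                    return found
        for value in payload.values():
            found = extract_token_ids(value, depth + 1)
            if found is not None:
                return found
    return None

print(extract_token_ids({"token_ids": (101, 7592, 102)}))  # [101, 7592, 102]
print(extract_token_ids([[1, 2], [3]]))                    # [1, 2, 3]
```

Unknown payload shapes fall through to `None`, which the caller treats as "use the sequence unchanged".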
```python
        @functools.wraps(original_seq2text)
        def unsloth_seq2text(tokenizer, seq, *, use_cache = True):
            normalized_seq = _extract_token_ids(seq)
            if normalized_seq is None:
                normalized_seq = seq
            return original_seq2text(tokenizer, normalized_seq, use_cache = use_cache)
        pass
```
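The wrapper above follows a common idempotent monkey-patch pattern: wrap the original with `functools.wraps`, mark the wrapper with a sentinel attribute, and check that sentinel before patching again. A minimal self-contained sketch of the pattern, with hypothetical names standing in for the vLLM module and helper:

```python
import functools
import types

def _seq2text(tokenizer, seq):             # stand-in for the patched helper
    return f"decoded:{seq}"

def patch_seq2text(module):
    original = module._seq2text
    if getattr(original, "__patched__", False):
        return                             # already patched: never double-wrap
    @functools.wraps(original)
    def wrapper(tokenizer, seq):
        if isinstance(seq, dict):          # normalize before delegating
            seq = seq.get("token_ids", seq)
        return original(tokenizer, seq)
    wrapper.__patched__ = True
    module._seq2text = wrapper

mod = types.SimpleNamespace(_seq2text=_seq2text)
patch_seq2text(mod)
first = mod._seq2text
patch_seq2text(mod)                        # second call is a no-op
assert mod._seq2text is first
print(mod._seq2text(None, {"token_ids": [1, 2]}))  # decoded:[1, 2]
```

The sentinel check matters because `patch_vllm` style entry points may run more than once per process; without it each run would add another layer of wrapping.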
```python
        unsloth_seq2text.__unsloth_patched_seq2text__ = True
        mm_processor._seq2text = unsloth_seq2text
    except:
```
Contributor comment on the bare `except`:
> Using a bare `except` here is discouraged: it silently swallows every error, including `KeyboardInterrupt` and `SystemExit`. Catching `Exception` (or the specific exceptions expected from the import and patching) is safer.
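The distinction the reviewer is pointing at can be demonstrated directly: `KeyboardInterrupt` and `SystemExit` derive from `BaseException`, not `Exception`, so a bare `except` traps them while `except Exception` lets them propagate. A small sketch:

```python
# A bare `except` also traps interpreter-exit requests; `except Exception`
# does not, because SystemExit/KeyboardInterrupt subclass BaseException only.
print(issubclass(KeyboardInterrupt, Exception))      # False
print(issubclass(KeyboardInterrupt, BaseException))  # True

def swallow_all():
    try:
        raise SystemExit
    except:            # bare: even SystemExit is swallowed
        return "swallowed"

def swallow_errors():
    try:
        raise SystemExit
    except Exception:  # narrower: SystemExit escapes
        return "swallowed"

print(swallow_all())   # swallowed
try:
    swallow_errors()
except SystemExit:
    print("SystemExit propagated")
```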
```python
        pass
pass
```
```python
def set_inductor_config(config, runtime_shape):
    if isinstance(runtime_shape, int):
        # for a specific batchsize, tuning triton kernel parameters
```
@@ -409,6 +482,10 @@ def patch_vllm_lora_load_tensors():

```python
    return
pass


def patch_vllm_multimodal_seq2text():
    return
pass
```
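The second hunk defines a no-op fallback so the symbol always exists even when vLLM is unavailable: callers invoke the patch unconditionally and it does nothing if there is nothing to patch. The guard pattern, sketched with a hypothetical backend name:

```python
import importlib.util

# Define either a real patcher or a harmless stub, depending on whether the
# optional dependency is importable ("some_optional_backend" is made up).
if importlib.util.find_spec("some_optional_backend") is not None:
    def apply_patch():
        ...  # real patching against the backend would go here
else:
    def apply_patch():
        return  # backend absent: patching is a no-op

apply_patch()  # safe to call either way
print("ok")
```

This keeps the call site (`patch_vllm` here) free of import checks; both branches expose the same name and signature.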
```python
def patch_vllm_set_inductor_config():
    return
pass
```
@@ -788,6 +865,7 @@ def patch_vllm(debug = True):

```python
    patch_bitsandbytes_quant_state()
    patch_vllm_bitsandbytes()
    patch_vllm_lora_tokenizer()
    patch_vllm_multimodal_seq2text()
    patch_vllm_lora_load_tensors()
    if os.getenv("UNSLOTH_VLLM_STANDBY", "0") == "1":
        if Version("0.10.0") <= Version(vllm_version) < Version("0.11.0"):
```
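The registration above gates a standby patch on the installed vLLM version. In production code the `Version` comparison from the `packaging` library is the right tool (it handles pre-releases and dev builds); as a dependency-free illustration, the same range check can be sketched with tuples, where the version string is an assumed example value:

```python
# Dependency-free sketch of a semver-style range check. Real code should
# prefer packaging.version.Version; this ignores pre-release suffixes.
def version_tuple(v):
    return tuple(int(part) for part in v.split(".")[:3])

vllm_version = "0.10.2"  # hypothetical installed version
if version_tuple("0.10.0") <= version_tuple(vllm_version) < version_tuple("0.11.0"):
    print("apply 0.10.x-specific standby patch")
else:
    print("skip standby patch")
```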
Contributor comment:
> I see where this comes from. Is this a necessity or just nice to have? From the looks of it, this is primarily aimed at `mm_proj` and similar modules' LoRA, which is kind of experimental according to their docs.