Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions unsloth_zoo/vllm_lora_worker_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,24 @@ def list_adapters(self) -> Set[int]:
else:
return set(self._adapter_manager.list_adapters())

def supports_tower_connector_lora(self) -> bool:
    """Report whether the underlying adapter manager supports tower/connector LoRA.

    vLLM v1 vision code expects this method on the worker manager. Older
    manager implementations may not expose the capability, so we answer
    conservatively: missing manager, a falsy ``supports_mm`` flag, a missing
    capability attribute, or a capability callable that raises all map to
    ``False``.
    """
    manager = getattr(self, "_adapter_manager", None)
    if manager is None:
        return False
    if not bool(getattr(manager, "supports_mm", True)):
        return False

    flag = getattr(manager, "supports_tower_connector_lora", False)
    if not callable(flag):
        return bool(flag)
    # The capability may be exposed as a method; treat any failure as "no".
    try:
        return bool(flag())
    except Exception:
        return False


# from vllm try to import WorkerLoRAManager
try:
Expand Down
78 changes: 78 additions & 0 deletions unsloth_zoo/vllm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,79 @@ def patch_vllm_lora_load_tensors():
pass
pass

def unsloth_extract_token_ids(payload, depth = 0):
    """Flatten a (possibly nested) multimodal token payload into ``list[int]``.

    vLLM multimodal prompt updates can wrap token ids in tensors, numpy
    arrays, tuples, nested lists, or dicts keyed by common field names.
    Returns a flat list of ints when *payload* can be interpreted as token
    ids, otherwise ``None`` (the caller then falls back to the original
    payload). Strings are deliberately NOT treated as token ids here — text
    payloads are handled unchanged by the wrapper — so this function has a
    single, consistent return type of ``list[int] | None``.

    ``depth`` guards against runaway recursion on pathological payloads.
    """
    if depth >= 8:
        return None
    # Tensors / arrays become plain Python scalars or (nested) lists first.
    if torch.is_tensor(payload):
        payload = payload.tolist()
    elif isinstance(payload, np.ndarray):
        payload = payload.tolist()
    if isinstance(payload, (int, np.integer)):
        return [int(payload)]
    if isinstance(payload, tuple):
        payload = list(payload)
    if isinstance(payload, list):
        if len(payload) == 0:
            return payload
        # Fast path: a flat list of ints.
        if all(isinstance(x, (int, np.integer)) for x in payload):
            return [int(x) for x in payload]
        # Nested case: every element must itself resolve to token ids,
        # otherwise the whole payload is not a token-id sequence.
        merged = []
        for item in payload:
            found = unsloth_extract_token_ids(item, depth = depth + 1)
            if found is None:
                return None
            merged.extend(found)
        return merged
    if isinstance(payload, dict):
        # Try well-known token-id field names first, then any value.
        preferred_keys = (
            "token_ids",
            "input_ids",
            "ids",
            "prompt_token_ids",
            "token_id",
            "tokens",
            "prompt",
            "content",
        )
        for key in preferred_keys:
            if key in payload:
                found = unsloth_extract_token_ids(payload[key], depth = depth + 1)
                if found is not None:
                    return found
        for value in payload.values():
            found = unsloth_extract_token_ids(value, depth = depth + 1)
            if found is not None:
                return found
    # Strings and any other object are not token ids.
    return None
pass

def patch_vllm_multimodal_seq2text():
    # vLLM multimodal prompt updates can pass token payloads as nested
    # dict/list/tensor objects. Normalize these payloads before decode.
    try:
        import vllm.multimodal.processing.processor as mm_processor
        original_seq2text = mm_processor._seq2text
        # Idempotence: don't re-wrap an already-patched function.
        if getattr(original_seq2text, "__unsloth_patched_seq2text__", False):
            return

        @functools.wraps(original_seq2text)
        def unsloth_seq2text(tokenizer, seq, *, use_cache = True):
            # Strings are already text; only token-id payloads need
            # normalization. Unrecognized payloads pass through untouched.
            if isinstance(seq, str):
                return original_seq2text(tokenizer, seq, use_cache = use_cache)
            normalized_seq = unsloth_extract_token_ids(seq)
            if normalized_seq is None:
                normalized_seq = seq
            return original_seq2text(tokenizer, normalized_seq, use_cache = use_cache)
        pass

        unsloth_seq2text.__unsloth_patched_seq2text__ = True
        mm_processor._seq2text = unsloth_seq2text
    except Exception:
        # Best-effort patch: the module layout differs across vLLM versions,
        # so missing modules/attributes simply leave vLLM unpatched.
        pass
pass

def set_inductor_config(config, runtime_shape):
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
Expand Down Expand Up @@ -409,6 +482,10 @@ def patch_vllm_lora_load_tensors():
return
pass

def patch_vllm_multimodal_seq2text():
    """No-op stub kept so callers can invoke the patch unconditionally."""
    return
pass

def patch_vllm_set_inductor_config():
    """No-op stub kept so callers can invoke the patch unconditionally."""
    return
pass
Expand Down Expand Up @@ -788,6 +865,7 @@ def patch_vllm(debug = True):
patch_bitsandbytes_quant_state()
patch_vllm_bitsandbytes()
patch_vllm_lora_tokenizer()
patch_vllm_multimodal_seq2text()
patch_vllm_lora_load_tensors()
if os.getenv("UNSLOTH_VLLM_STANDBY", "0") == "1":
if Version("0.10.0") <= Version(vllm_version) < Version("0.11.0"):
Expand Down