
DebertaV2 fix for running with large batches #846

Open

vrdn-23 wants to merge 1 commit into huggingface:main from vrdn-23:vrdn-23/batching-deberta-fix

Conversation

Contributor

@vrdn-23 vrdn-23 commented Mar 20, 2026

What does this PR do?

Fix shape mismatch in DeBERTa v2 embeddings mask during batched inference

Problem

DeBERTa v2 models fail with a "shape mismatch in broadcast_mul" error under concurrent load when requests are batched together (batch_size > 1 with padding):

{
  "level": "ERROR",
  "message": "shape mismatch in broadcast_mul, lhs: [2, 348, 768], rhs: [2, 348, 1, 1]"
}

At 50 concurrent users, 91% of requests fail with this error. Single requests always succeed because they bypass the padding/masking path.

Root Cause

In DebertaV2Embeddings::forward, the mask reshape guard at line 179 compared shape values instead of rank (number of dimensions):

// Bug: compares [2, 348, 1] != [2, 348, 768] → true (different values)
if mask.dims() != embeddings.dims() {

When batch_size > 1 with padding, the attention mask is created as [B, L, 1] (3D) and embeddings are [B, L, H] (3D). Same rank, different values. The condition incorrectly evaluates to true, causing an unnecessary unsqueeze(2) that produces a 4D tensor [B, L, 1, 1] which cannot broadcast with the 3D embeddings [B, L, H].

This only affects DeBERTa v2 — no other model applies a mask inside the embeddings layer.

Fix

Compare tensor rank instead of shape values:

// Fixed: compares 3 == 3 → false, skips reshape (mask already broadcasts correctly)
if mask.dims().len() != embeddings.dims().len() {

A 3D mask [B, L, 1] already broadcasts correctly with [B, L, H]. The reshape block is only needed when the mask has a different number of dimensions (e.g., 2D [B, L] → needs unsqueeze to become [B, L, 1]).
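To make the two guards concrete, here is a minimal standalone sketch using plain slices in place of candle tensor shapes (function names are illustrative, not from the PR), evaluated on the shapes from the error log:

```rust
/// Buggy guard: compares shape *values*, so [B, L, 1] vs [B, L, H]
/// triggers a reshape even though the ranks already match.
fn needs_reshape_buggy(mask_dims: &[usize], emb_dims: &[usize]) -> bool {
    mask_dims != emb_dims
}

/// Fixed guard: compares *rank* (number of dimensions) only.
fn needs_reshape_fixed(mask_dims: &[usize], emb_dims: &[usize]) -> bool {
    mask_dims.len() != emb_dims.len()
}

fn main() {
    let mask = [2, 348, 1];    // [B, L, 1] padding mask
    let emb = [2, 348, 768];   // [B, L, H] embeddings

    // Buggy: fires, unsqueezing to [2, 348, 1, 1] and breaking broadcast_mul.
    assert!(needs_reshape_buggy(&mask, &emb));

    // Fixed: ranks match (3 == 3), so the mask broadcasts as-is.
    assert!(!needs_reshape_fixed(&mask, &emb));

    // A 2D mask [B, L] still gets the unsqueeze it needs.
    assert!(needs_reshape_fixed(&[2, 348], &emb));

    println!("ok");
}
```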

Verification

Load tested with k6 ramping to 50 concurrent users:

| Metric                | Before | After |
|-----------------------|--------|-------|
| Success rate          | 8.6%   | 100%  |
| Shape mismatch errors | 26,142 | 0     |

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • Did you write any new necessary tests? If applicable, did you include or update the insta snapshots?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@alvarobartt @kozistr

Contributor Author

vrdn-23 commented Mar 20, 2026

The load test I ran to reproduce this. The error is easy to see when running this against main:

import http from 'k6/http';
import { check, sleep } from 'k6';
import { Counter } from 'k6/metrics';

const shapeMismatchErrors = new Counter('shape_mismatch_errors');

export const options = {
  stages: [
    { duration: '10s', target: 2 },
    { duration: '10s', target: 10 },
    { duration: '10s', target: 25 },
    { duration: '10s', target: 50 },
    { duration: '10s', target: 50 },
  ],
  insecureSkipTLSVerify: true,
};

const inputs = [
  "What is Deep Learning?",
  "How does a transformer model work in natural language processing?",
  "AI",
  "The quick brown fox jumps over the lazy dog and then runs across the field to find shelter from the incoming storm",
  "Hello",
  "Explain quantum computing in simple terms for a beginner who has no background in physics or computer science",
  "Short",
  "This is a medium length sentence for testing purposes",
];

export default function () {
  const payload = JSON.stringify({
    inputs: inputs[Math.floor(Math.random() * inputs.length)],
  });

  const res = http.post('http://localhost:8000/predict', payload, {
    headers: { 'Content-Type': 'application/json' },
  });

  const isShapeMismatch = res.status !== 200 &&
    res.body && res.body.includes('shape mismatch');

  if (isShapeMismatch) {
    shapeMismatchErrors.add(1);
    console.error(`SHAPE MISMATCH ERROR at VU ${__VU}: ${res.body}`);
  }

  check(res, {
    'status is 200': (r) => r.status === 200,
  });
}

@michaelfeil
Contributor

@vrdn-23 How is this not caught during warmup? Isn't that what warmup is for?

Contributor Author

vrdn-23 commented Mar 20, 2026

I think it's because warmup always creates batches of the same size, and this particular branch of the padding/masking code only gets activated when a batch contains sequences of unequal length. It might also be helpful to have warmup run both same-size batches (at max size, to ensure the GPU has enough memory) and unequal-size batches (to catch padding issues). Or maybe that's overkill. Any thoughts @michaelfeil?

@alvarobartt alvarobartt added this to the v1.10.0 milestone Mar 23, 2026