Skip to content

Commit 4956871

Browse files
authored
Fix odd dims in CUDA, and add odd dims support for P016 (#1462)
1 parent 20080ca commit 4956871

13 files changed

Lines changed: 180 additions & 66 deletions

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -362,33 +362,27 @@ static torch::stable::Tensor convertP016FrameToRGB16(
362362
int bitDepth,
363363
const float colorMatrix[3][4],
364364
bool colorMatrixChanged) {
365-
int height = avFrame->height;
366-
int width = avFrame->width;
367-
STD_TORCH_CHECK(
368-
height % 2 == 0 && width % 2 == 0,
369-
"convertP016FrameToRGB16 expects even avFrame dimensions, got ",
370-
height,
371-
"x",
372-
width,
373-
". Report a bug if you see this message.");
374-
STD_TORCH_CHECK(
375-
outputDims.height == height && outputDims.width == width,
376-
"outputDims ",
377-
outputDims.height,
378-
"x",
379-
outputDims.width,
380-
" are not consistent with avFrame dimensions ",
381-
height,
382-
"x",
383-
width,
384-
". Report a bug if you see this message.");
365+
// avFrame dimensions may be odd (NVDEC display area for VP9 etc.). P016
366+
// color conversion requires even dimensions, so we round up to even for the
367+
// kernel, then crop to outputDims.
368+
int frameHeight = avFrame->height;
369+
int frameWidth = avFrame->width;
370+
int height = roundUpToEven(frameHeight);
371+
int width = roundUpToEven(frameWidth);
372+
373+
int outHeight = outputDims.height;
374+
int outWidth = outputDims.width;
375+
bool needsCrop = (outHeight != height) || (outWidth != width);
385376

386377
torch::stable::Tensor dst;
387-
if (preAllocatedOutputTensor.has_value()) {
378+
if (needsCrop) {
379+
dst = allocateEmptyHWCTensor(
380+
FrameDims(height, width), device, OutputDtype::FLOAT32);
381+
} else if (preAllocatedOutputTensor.has_value()) {
388382
dst = preAllocatedOutputTensor.value();
389383
} else {
390384
dst = allocateEmptyHWCTensor(
391-
FrameDims(height, width), device, OutputDtype::FLOAT32);
385+
FrameDims(outHeight, outWidth), device, OutputDtype::FLOAT32);
392386
}
393387

394388
cudaStream_t stream = getCurrentCudaStream(device.index());
@@ -408,6 +402,20 @@ static torch::stable::Tensor convertP016FrameToRGB16(
408402
colorMatrixChanged,
409403
stream);
410404

405+
if (needsCrop) {
406+
if (outHeight != height) {
407+
dst = torch::stable::narrow(dst, /*dim=*/0, /*start=*/0, outHeight);
408+
}
409+
if (outWidth != width) {
410+
dst = torch::stable::narrow(dst, /*dim=*/1, /*start=*/0, outWidth);
411+
dst = torch::stable::contiguous(dst);
412+
}
413+
if (preAllocatedOutputTensor.has_value()) {
414+
torch::stable::copy_(preAllocatedOutputTensor.value(), dst);
415+
return preAllocatedOutputTensor.value();
416+
}
417+
return dst;
418+
}
411419
return dst;
412420
}
413421

@@ -895,9 +903,13 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame(
895903
? AVCOL_RANGE_JPEG
896904
: AVCOL_RANGE_MPEG;
897905

898-
// Below: Ask Claude. I'm not going to even pretend.
906+
// NVDEC's surface layout places the UV plane after the Y plane. For
907+
// NV12/P016 the Y plane has an even number of rows (NVDEC rounds up
908+
// internally), so we must use the rounded-up height for the UV offset.
909+
unsigned int evenHeight = roundUpToEven(height);
899910
avFrame->data[0] = reinterpret_cast<uint8_t*>(framePtr);
900-
avFrame->data[1] = reinterpret_cast<uint8_t*>(framePtr + (pitch * height));
911+
avFrame->data[1] =
912+
reinterpret_cast<uint8_t*>(framePtr + (pitch * evenHeight));
901913
avFrame->data[2] = nullptr;
902914
avFrame->data[3] = nullptr;
903915
avFrame->linesize[0] = pitch;

src/torchcodec/_core/CUDACommon.cpp

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -289,33 +289,16 @@ torch::stable::Tensor convertNV12FrameToRGB(
289289
cudaStream_t nvdecStream,
290290
std::optional<torch::stable::Tensor> preAllocatedOutputTensor,
291291
const FrameDims& outputDims) {
292-
// avFrame dimensions must be even (NV12 requirement). When the original
293-
// video has odd dimensions, the caller pads to even and passes the original
294-
// size via outputDims so we can crop after conversion.
295-
int nv12Height = avFrame->height;
296-
int nv12Width = avFrame->width;
297-
STD_TORCH_CHECK(
298-
nv12Height % 2 == 0 && nv12Width % 2 == 0,
299-
"convertNV12FrameToRGB expects even avFrame dimensions, got ",
300-
nv12Height,
301-
"x",
302-
nv12Width,
303-
". Report a bug if you see this message.");
292+
// avFrame dimensions may be odd (NVDEC display area for VP9 etc.). NV12
293+
// color conversion requires even dimensions, so we round up to even for the
294+
// conversion, then crop to outputDims.
295+
int frameHeight = avFrame->height;
296+
int frameWidth = avFrame->width;
297+
int nv12Height = roundUpToEven(frameHeight);
298+
int nv12Width = roundUpToEven(frameWidth);
304299

305300
int outHeight = outputDims.height;
306301
int outWidth = outputDims.width;
307-
STD_TORCH_CHECK(
308-
roundUpToEven(outHeight) == nv12Height &&
309-
roundUpToEven(outWidth) == nv12Width,
310-
"outputDims ",
311-
outHeight,
312-
"x",
313-
outWidth,
314-
" are not consistent with avFrame dimensions ",
315-
nv12Height,
316-
"x",
317-
nv12Width,
318-
". Report a bug if you see this message.");
319302
bool needsCrop = (outHeight != nv12Height) || (outWidth != nv12Width);
320303

321304
torch::stable::Tensor dst;
File renamed without changes.
File renamed without changes.
68.9 KB
Binary file not shown.
68.7 KB
Binary file not shown.
67.5 KB
Binary file not shown.
66.7 KB
Binary file not shown.
69.9 KB
Binary file not shown.

0 commit comments

Comments
 (0)