Skip to content

Commit d27bfe3

Browse files
authored
Add File-like and bytes support to WavDecoder (#1461)
1 parent bab9aff commit d27bfe3

16 files changed

Lines changed: 442 additions & 215 deletions

src/torchcodec/_core/AVIOContextHolder.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,17 @@ AVIOContext* AVIOContextHolder::getAVIOContext() {
5858
return avioContext_.get();
5959
}
6060

61+
// Base-class fallback for the generic read primitive. The arguments are
// ignored and STD_TORCH_CHECK is tripped unconditionally; holders that
// support direct reads override this method.
int AVIOContextHolder::read(uint8_t* /*buf*/, int /*size*/) {
  STD_TORCH_CHECK(false, "read() is not supported by this AVIOContextHolder");
}
64+
65+
// Base-class fallback for the generic seek primitive. The arguments are
// ignored and STD_TORCH_CHECK is tripped unconditionally; holders that
// support direct seeking override this method.
int64_t AVIOContextHolder::seek(int64_t /*offset*/, int /*whence*/) {
  STD_TORCH_CHECK(false, "seek() is not supported by this AVIOContextHolder");
}
68+
69+
// Base-class fallback for the size query. STD_TORCH_CHECK is tripped
// unconditionally; holders that can report a size override this method.
int64_t AVIOContextHolder::getSize() {
  STD_TORCH_CHECK(
      false, "getSize() is not supported by this AVIOContextHolder");
}
73+
6174
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOContextHolder.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#pragma once
88

99
#include "FFMPEGCommon.h"
10+
#include "StableABICompat.h"
1011

1112
namespace facebook::torchcodec {
1213

@@ -34,11 +35,18 @@ namespace facebook::torchcodec {
3435
// 3. A generic handle for those that just need to manage having access to an
3536
// AVIOContext, but aren't necessarily concerned with how it was customized:
3637
// typically, the SingleStreamDecoder.
37-
class AVIOContextHolder {
38+
class FORCE_PUBLIC_VISIBILITY AVIOContextHolder {
3839
public:
3940
virtual ~AVIOContextHolder();
4041
AVIOContext* getAVIOContext();
4142

43+
// Generic I/O primitives used by consumers that don't go through
44+
// FFmpeg's AVIO layer (e.g. WavDecoder). Derived classes override
45+
// the ones they support.
46+
virtual int read(uint8_t* buf, int size);
47+
virtual int64_t seek(int64_t offset, int whence);
48+
virtual int64_t getSize();
49+
4250
protected:
4351
// Make constructor protected to prevent anyone from constructing
4452
// an AVIOContextHolder without deriving it. (Ordinarily this would be
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#include "AVIOFileContext.h"
8+
9+
#include <filesystem>
10+
#include "StableABICompat.h"
11+
12+
namespace facebook::torchcodec {
13+
14+
// Opens `path` as a binary input stream and records the file's size as
// reported by std::filesystem. Both a failed open and a failed size query
// are reported through STD_TORCH_CHECK.
AVIOFileContext::AVIOFileContext(const std::string& path)
    : file_(path, std::ios::binary) {
  STD_TORCH_CHECK(file_.is_open(), "Failed to open file: ", path);

  try {
    const auto sizeOnDisk = std::filesystem::file_size(path);
    fileSize_ = static_cast<int64_t>(sizeOnDisk);
  } catch (const std::filesystem::filesystem_error& fsError) {
    // Re-report the filesystem failure through the project's check macro so
    // the message carries the offending path.
    STD_TORCH_CHECK(
        false,
        "Failed to get file size for: ",
        path,
        ". Error: ",
        fsError.what());
  }
}
24+
25+
// Reads up to `size` bytes from the file into `buf`.
//
// Returns the number of bytes actually extracted from the stream, or -1 when
// nothing at all could be read (e.g. the stream is already at end-of-file).
int AVIOFileContext::read(uint8_t* buf, int size) {
  char* const destination = reinterpret_cast<char*>(buf);
  file_.read(destination, size);

  // gcount() reports what the last unformatted read extracted; it can be
  // less than `size` on a short read.
  const int extracted = static_cast<int>(file_.gcount());
  return extracted == 0 ? -1 : extracted;
}
33+
34+
// Repositions the file's read cursor.
//
// `whence` follows the C stdio convention: SEEK_SET (from the beginning),
// SEEK_CUR (from the current position) or SEEK_END (from the end). Any other
// value is rejected by returning -1 without touching the stream.
//
// Returns the resulting absolute position in the file. A seek that the
// stream itself rejects is reported through STD_TORCH_CHECK.
int64_t AVIOFileContext::seek(int64_t offset, int whence) {
  std::ios_base::seekdir dir = std::ios::beg;
  switch (whence) {
    case SEEK_SET:
      dir = std::ios::beg;
      break;
    case SEEK_CUR:
      dir = std::ios::cur;
      break;
    case SEEK_END:
      dir = std::ios::end;
      break;
    default:
      return -1;
  }

  // An earlier read() that ran into end-of-file leaves both eofbit and
  // failbit set on the stream. seekg() only clears eofbit on its own and is
  // a no-op while failbit is set, so without this reset every seek after a
  // short read would be reported as a failure even for a valid offset.
  file_.clear();

  file_.seekg(offset, dir);
  STD_TORCH_CHECK(!file_.fail(), "Failed to seek in file");
  return static_cast<int64_t>(file_.tellg());
}
53+
54+
int64_t AVIOFileContext::getSize() {
55+
return fileSize_;
56+
}
57+
58+
} // namespace facebook::torchcodec
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (c) Meta Platforms, Inc. and affiliates.
2+
// All rights reserved.
3+
//
4+
// This source code is licensed under the BSD-style license found in the
5+
// LICENSE file in the root directory of this source tree.
6+
7+
#pragma once
8+
9+
#include <fstream>
10+
#include <string>
11+
#include "AVIOContextHolder.h"
12+
13+
namespace facebook::torchcodec {
14+
15+
// For reading from a file on disk. Unlike the other AVIOContextHolder
16+
// subclasses, this one does NOT create an FFmpeg AVIOContext — it only
17+
// provides the read/seek/getSize primitives for consumers like
18+
// WavDecoder that do their own parsing.
19+
class AVIOFileContext : public AVIOContextHolder {
20+
public:
21+
explicit AVIOFileContext(const std::string& path);
22+
23+
int read(uint8_t* buf, int size) override;
24+
int64_t seek(int64_t offset, int whence) override;
25+
int64_t getSize() override;
26+
27+
private:
28+
std::ifstream file_;
29+
int64_t fileSize_;
30+
};
31+
32+
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOFileLikeContext.cpp

Lines changed: 51 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,26 +33,54 @@ AVIOFileLikeContext::AVIOFileLikeContext(
3333
py::hasattr(fileLike, "seek"),
3434
"File like object must implement a seek method.");
3535
}
36-
createAVIOContext(&read, &write, &seek, &fileLike_, isForWriting);
36+
createAVIOContext(
37+
&readCallback, &writeCallback, &seekCallback, this, isForWriting);
3738
}
3839

39-
int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) {
40-
auto fileLike = static_cast<UniquePyObject*>(opaque);
40+
int AVIOFileLikeContext::readCallback(
41+
void* opaque,
42+
uint8_t* buf,
43+
int buf_size) {
44+
auto self = static_cast<AVIOFileLikeContext*>(opaque);
45+
int result = self->read(buf, buf_size);
46+
return result < 0 ? AVERROR_EOF : result;
47+
}
4148

42-
// Note that we acquire the GIL outside of the loop. This is likely more
43-
// efficient than releasing and acquiring it each loop iteration.
49+
int64_t
50+
AVIOFileLikeContext::seekCallback(void* opaque, int64_t offset, int whence) {
51+
if (whence == AVSEEK_SIZE) {
52+
// Size of file-like is typically unknown, since the data is potentially
53+
// streaming.
54+
return AVERROR(EIO);
55+
}
56+
auto self = static_cast<AVIOFileLikeContext*>(opaque);
57+
return self->seek(offset, whence);
58+
}
59+
60+
int AVIOFileLikeContext::writeCallback(
61+
void* opaque,
62+
const uint8_t* buf,
63+
int buf_size) {
64+
auto self = static_cast<AVIOFileLikeContext*>(opaque);
65+
py::gil_scoped_acquire gil;
66+
py::bytes bytes_obj(reinterpret_cast<const char*>(buf), buf_size);
67+
return py::cast<int>(self->fileLike_->attr("write")(bytes_obj));
68+
}
69+
70+
int AVIOFileLikeContext::read(uint8_t* buf, int size) {
4471
py::gil_scoped_acquire gil;
4572

4673
int totalNumRead = 0;
47-
while (totalNumRead < buf_size) {
48-
int request = buf_size - totalNumRead;
49-
50-
// The Python method returns the actual bytes, which we access through the
51-
// py::bytes wrapper. That wrapper, however, does not provide us access to
52-
// the underlying data pointer, which we need for the memcpy below. So we
53-
// convert the bytes to a string_view to get access to the data pointer.
54-
// Becauase it's a view and not a copy, it should be cheap.
55-
auto bytesRead = static_cast<py::bytes>((*fileLike)->attr("read")(request));
74+
while (totalNumRead < size) {
75+
int request = size - totalNumRead;
76+
77+
// The Python method returns the actual bytes, which we access through
78+
// the py::bytes wrapper. That wrapper, however, does not provide us
79+
// access to the underlying data pointer, which we need for the memcpy
80+
// below. So we convert the bytes to a string_view to get access to
81+
// the data pointer. Because it's a view and not a copy, it should be
82+
// cheap.
83+
auto bytesRead = static_cast<py::bytes>(fileLike_->attr("read")(request));
5684
auto bytesView = static_cast<std::string_view>(bytesRead);
5785

5886
int numBytesRead = static_cast<int>(bytesView.size());
@@ -66,33 +94,26 @@ int AVIOFileLikeContext::read(void* opaque, uint8_t* buf, int buf_size) {
6694
request,
6795
" bytes but, received ",
6896
numBytesRead,
69-
" bytes. The given object does not conform to read protocol of file object.");
97+
" bytes. The given object does not conform to read protocol "
98+
"of file object.");
7099

71100
std::memcpy(buf, bytesView.data(), numBytesRead);
72101
buf += numBytesRead;
73102
totalNumRead += numBytesRead;
74103
}
75104

76-
return totalNumRead == 0 ? AVERROR_EOF : totalNumRead;
105+
return totalNumRead == 0 ? -1 : totalNumRead;
77106
}
78107

79-
int64_t AVIOFileLikeContext::seek(void* opaque, int64_t offset, int whence) {
80-
// We do not know the file size.
81-
if (whence == AVSEEK_SIZE) {
82-
return AVERROR(EIO);
83-
}
84-
85-
auto fileLike = static_cast<UniquePyObject*>(opaque);
108+
int64_t AVIOFileLikeContext::seek(int64_t offset, int whence) {
86109
py::gil_scoped_acquire gil;
87-
return py::cast<int64_t>((*fileLike)->attr("seek")(offset, whence));
110+
return py::cast<int64_t>(fileLike_->attr("seek")(offset, whence));
88111
}
89112

90-
int AVIOFileLikeContext::write(void* opaque, const uint8_t* buf, int buf_size) {
91-
auto fileLike = static_cast<UniquePyObject*>(opaque);
92-
py::gil_scoped_acquire gil;
93-
py::bytes bytes_obj(reinterpret_cast<const char*>(buf), buf_size);
94-
95-
return py::cast<int>((*fileLike)->attr("write")(bytes_obj));
113+
// Reports INT64_MAX instead of a real size: a file-like object may be
// streaming its data, so its total length is typically not known.
int64_t AVIOFileLikeContext::getSize() {
  return INT64_MAX;
}
97118

98119
} // namespace facebook::torchcodec

src/torchcodec/_core/AVIOFileLikeContext.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,14 @@ class AVIOFileLikeContext : public AVIOContextHolder {
2727
public:
2828
explicit AVIOFileLikeContext(const py::object& fileLike, bool isForWriting);
2929

30+
int read(uint8_t* buf, int size) override;
31+
int64_t seek(int64_t offset, int whence) override;
32+
int64_t getSize() override;
33+
3034
private:
31-
static int read(void* opaque, uint8_t* buf, int buf_size);
32-
static int64_t seek(void* opaque, int64_t offset, int whence);
33-
static int write(void* opaque, const uint8_t* buf, int buf_size);
35+
static int readCallback(void* opaque, uint8_t* buf, int buf_size);
36+
static int64_t seekCallback(void* opaque, int64_t offset, int whence);
37+
static int writeCallback(void* opaque, const uint8_t* buf, int buf_size);
3438

3539
// Note that we dynamically allocate the Python object because we need to
3640
// strictly control when its destructor is called. We must hold the GIL

0 commit comments

Comments
 (0)