Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchAudio][stream_reader] Make StreamingMediaDecoderBytes available for C++ usage #3742

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/libtorio/ffmpeg/ffmpeg.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ struct StreamParams {
AVRational time_base{};
int stream_index{};
};

/// Read cursor over a block of bytes held in memory.
///
/// Used as the opaque state for custom-IO callbacks that serve data from a
/// buffer instead of a file.
struct BytesWrapper {
  /// Non-owning view of the bytes to serve. std::string_view does not keep
  /// the data alive, so the viewed storage must outlive every use of this
  /// object.
  std::string_view src;
  /// Offset of the next byte to read, counted from the start of ``src``.
  size_t index{0};
};

} // namespace io
} // namespace torio

Expand Down
56 changes: 0 additions & 56 deletions src/libtorio/ffmpeg/pybind/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,62 +188,6 @@ struct StreamingMediaEncoderFileObj : private FileObj,
py::hasattr(fileobj, "seek") ? &seek_func : nullptr) {}
};

//////////////////////////////////////////////////////////////////////////////
// StreamingMediaDecoder/Encoder Bytes
//////////////////////////////////////////////////////////////////////////////
/// State shared by read_bytes and seek_bytes: the bytes being decoded plus
/// the current read position.
struct BytesWrapper {
  /// View of the source bytes. The view is non-owning, so the underlying
  /// storage has to stay valid for as long as this object is in use.
  std::string_view src;
  /// Current position within ``src``, in bytes from its beginning.
  size_t index{0};
};

/// FFmpeg custom-IO read callback that serves data from an in-memory buffer.
///
/// @param opaque Pointer to the BytesWrapper holding the source bytes and the
/// current read position.
/// @param buf Destination buffer the bytes are copied into.
/// @param buf_size Capacity of ``buf`` in bytes.
/// @return The number of bytes copied into ``buf``, or AVERROR_EOF when there
/// is nothing left to read.
static int read_bytes(void* opaque, uint8_t* buf, int buf_size) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);

  // seek_bytes can leave the position beyond the end of the buffer. Check
  // before subtracting: size() - index is unsigned and would otherwise wrap
  // around to a huge value, letting the memcpy below read out of bounds.
  // A non-positive buf_size is treated the same way so that it is never
  // promoted to a large unsigned count.
  if (buf_size <= 0 || wrapper->index >= wrapper->src.size()) {
    return AVERROR_EOF;
  }

  const size_t remaining = wrapper->src.size() - wrapper->index;
  const size_t num_read = FFMIN(remaining, static_cast<size_t>(buf_size));
  memcpy(buf, wrapper->src.data() + wrapper->index, num_read);
  wrapper->index += num_read;
  // num_read <= buf_size, so it always fits in an int.
  return static_cast<int>(num_read);
}

/// FFmpeg custom-IO seek callback for an in-memory buffer.
///
/// @param opaque Pointer to the BytesWrapper holding the source bytes and the
/// current read position.
/// @param offset Byte offset, interpreted relative to ``whence``.
/// @param whence One of SEEK_SET, SEEK_CUR or SEEK_END, or AVSEEK_SIZE to
/// query the total size without moving the position.
/// @return The new absolute position (the buffer size for AVSEEK_SIZE), or a
/// negative value when the requested position is negative or not
/// representable. On failure the current position is left unchanged.
static int64_t seek_bytes(void* opaque, int64_t offset, int whence) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);
  const int64_t size = static_cast<int64_t>(wrapper->src.size());
  if (whence == AVSEEK_SIZE) {
    return size;
  }

  // Work out the reference point in signed arithmetic. Storing the result
  // straight into the unsigned index would turn a negative target into a
  // huge position.
  int64_t base = 0;
  if (whence == SEEK_SET) {
    base = 0;
  } else if (whence == SEEK_CUR) {
    base = static_cast<int64_t>(wrapper->index);
  } else if (whence == SEEK_END) {
    base = size;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected whence value: ", whence);
  }

  // Reject a reference point that does not fit in int64_t and an addition
  // that would overflow (signed overflow is undefined behavior).
  if (base < 0 || (offset > 0 && base > INT64_MAX - offset)) {
    return -1;
  }
  const int64_t target = base + offset;
  if (target < 0) {
    // Any negative return value signals failure to the caller.
    return -1;
  }

  // Positions beyond the end of the buffer are accepted, as with lseek.
  wrapper->index = static_cast<size_t>(target);
  return target;
}

/// StreamingMediaDecoder that reads its input from an in-memory byte buffer
/// through the read_bytes / seek_bytes callbacks.
///
/// BytesWrapper is listed as the first base so that it is initialized before
/// StreamingMediaDecoderCustomIO is constructed with a pointer to it.
struct StreamingMediaDecoderBytes : private BytesWrapper,
                                    public StreamingMediaDecoderCustomIO {
  /// @param src Bytes to decode. Only a view is stored, so the underlying
  /// storage must stay alive for as long as this object is used.
  /// @param format Input format forwarded to StreamingMediaDecoderCustomIO.
  /// @param option Options forwarded to StreamingMediaDecoderCustomIO.
  /// @param buffer_size Buffer size forwarded to
  /// StreamingMediaDecoderCustomIO.
  StreamingMediaDecoderBytes(
      std::string_view src,
      const c10::optional<std::string>& format,
      const c10::optional<std::map<std::string, std::string>>& option,
      int64_t buffer_size)
      : BytesWrapper{src},
        StreamingMediaDecoderCustomIO(
            // The callbacks cast the opaque pointer back to BytesWrapper*.
            // Pass a pointer to exactly that subobject instead of relying on
            // the base sharing the address of the complete object.
            static_cast<BytesWrapper*>(this),
            format,
            buffer_size,
            read_bytes,
            seek_bytes,
            option) {}
};

#ifndef TORIO_FFMPEG_EXT_NAME
#error TORIO_FFMPEG_EXT_NAME must be defined.
#endif
Expand Down
49 changes: 49 additions & 0 deletions src/libtorio/ffmpeg/stream_reader/stream_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -610,4 +610,53 @@ StreamingMediaDecoderCustomIO::StreamingMediaDecoderCustomIO(
: CustomInput(opaque, buffer_size, read_packet, seek),
StreamingMediaDecoder(io_ctx, format, option) {}

namespace {
/// FFmpeg custom-IO read callback that serves data from an in-memory buffer.
///
/// @param opaque Pointer to the BytesWrapper holding the source bytes and the
/// current read position.
/// @param buf Destination buffer the bytes are copied into.
/// @param buf_size Capacity of ``buf`` in bytes.
/// @return The number of bytes copied into ``buf``, or AVERROR_EOF when there
/// is nothing left to read.
int read_bytes(void* opaque, uint8_t* buf, int buf_size) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);

  // The position may lie beyond the end of the buffer after a seek. Check
  // before subtracting: size() - index is unsigned and would otherwise wrap
  // around to a huge value, letting the memcpy below read out of bounds.
  // A non-positive buf_size is treated the same way so that it is never
  // promoted to a large unsigned count.
  if (buf_size <= 0 || wrapper->index >= wrapper->src.size()) {
    return AVERROR_EOF;
  }

  const size_t remaining = wrapper->src.size() - wrapper->index;
  const size_t num_read = FFMIN(remaining, static_cast<size_t>(buf_size));
  memcpy(buf, wrapper->src.data() + wrapper->index, num_read);
  wrapper->index += num_read;
  // num_read <= buf_size, so it always fits in an int.
  return static_cast<int>(num_read);
}

/// FFmpeg custom-IO seek callback for an in-memory buffer.
///
/// @param opaque Pointer to the BytesWrapper holding the source bytes and the
/// current read position.
/// @param offset Byte offset, interpreted relative to ``whence``.
/// @param whence One of SEEK_SET, SEEK_CUR or SEEK_END, or AVSEEK_SIZE to
/// query the total size without moving the position.
/// @return The new absolute position (the buffer size for AVSEEK_SIZE), or a
/// negative value when the requested position is negative or not
/// representable. On failure the current position is left unchanged.
int64_t seek_bytes(void* opaque, int64_t offset, int whence) {
  BytesWrapper* wrapper = static_cast<BytesWrapper*>(opaque);
  const int64_t size = static_cast<int64_t>(wrapper->src.size());
  if (whence == AVSEEK_SIZE) {
    return size;
  }

  // Work out the reference point in signed arithmetic. Storing the result
  // straight into the unsigned index would turn a negative target into a
  // huge position.
  int64_t base = 0;
  if (whence == SEEK_SET) {
    base = 0;
  } else if (whence == SEEK_CUR) {
    base = static_cast<int64_t>(wrapper->index);
  } else if (whence == SEEK_END) {
    base = size;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected whence value: ", whence);
  }

  // Reject a reference point that does not fit in int64_t and an addition
  // that would overflow (signed overflow is undefined behavior).
  if (base < 0 || (offset > 0 && base > INT64_MAX - offset)) {
    return -1;
  }
  const int64_t target = base + offset;
  if (target < 0) {
    // Any negative return value signals failure to the caller.
    return -1;
  }

  // Positions beyond the end of the buffer are accepted, as with lseek;
  // read_bytes reports AVERROR_EOF from there.
  wrapper->index = static_cast<size_t>(target);
  return target;
}
} // namespace

//////////////////////////////////////////////////////////////////////////////
// StreamingMediaDecoder Bytes
//////////////////////////////////////////////////////////////////////////////
/// Construct a decoder that reads from ``src`` through the read_bytes and
/// seek_bytes callbacks.
///
/// Only a view of ``src`` is stored in the BytesWrapper base, so the
/// underlying storage must stay alive for as long as this object is used.
StreamingMediaDecoderBytes::StreamingMediaDecoderBytes(
    std::string_view src,
    const c10::optional<std::string>& format,
    const c10::optional<std::map<std::string, std::string>>& option,
    int64_t buffer_size)
    : BytesWrapper{src},
      StreamingMediaDecoderCustomIO(
          // The callbacks cast the opaque pointer back to BytesWrapper*.
          // Pass a pointer to exactly that subobject instead of relying on
          // the base sharing the address of the complete object.
          static_cast<BytesWrapper*>(this),
          format,
          buffer_size,
          read_bytes,
          seek_bytes,
          option) {}
} // namespace torio::io
22 changes: 22 additions & 0 deletions src/libtorio/ffmpeg/stream_reader/stream_reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,28 @@ class StreamingMediaDecoderCustomIO : private detail::CustomInput,
const c10::optional<OptionDict>& option = c10::nullopt);
};

//////////////////////////////////////////////////////////////////////////////
// StreamingMediaDecoder Bytes
//////////////////////////////////////////////////////////////////////////////
///
/// StreamingMediaDecoder whose input comes from a bytes buffer held in memory.
///
struct StreamingMediaDecoderBytes : private BytesWrapper,
                                    public StreamingMediaDecoderCustomIO {
  ///
  /// Set up a StreamingMediaDecoder whose read and seek functions pull data
  /// from an in-memory buffer.
  ///
  /// @param src The in-memory bytes to decode. It is taken as a
  /// ``std::string_view``, which does not own the data, so the underlying
  /// storage must remain valid while the decoder is in use.
  /// @param format Input format to use.
  /// @param option Custom option passed when initializing format context.
  /// @param buffer_size Size of the intermediate buffer that FFmpeg uses to
  /// hand data to the read_packet function.
  StreamingMediaDecoderBytes(
      std::string_view src,
      const c10::optional<std::string>& format,
      const c10::optional<std::map<std::string, std::string>>& option,
      int64_t buffer_size);
};

// For BC (backward compatibility): the StreamReader names are kept as aliases
// of the corresponding StreamingMediaDecoder types.
using StreamReader = StreamingMediaDecoder;
using StreamReaderCustomIO = StreamingMediaDecoderCustomIO;
Expand Down