Skip to content

Commit b830480

Browse files
authored
Remove AVFrameStream struct and remove raw AVFrame* pointers (#572)
1 parent 5713507 commit b830480

File tree

7 files changed

+52
-75
lines changed

7 files changed

+52
-75
lines changed

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
1717
void convertAVFrameToFrameOutputOnCuda(
1818
const torch::Device& device,
1919
[[maybe_unused]] const VideoDecoder::VideoStreamOptions& videoStreamOptions,
20-
[[maybe_unused]] VideoDecoder::AVFrameStream& avFrameStream,
20+
[[maybe_unused]] UniqueAVFrame& avFrame,
2121
[[maybe_unused]] VideoDecoder::FrameOutput& frameOutput,
2222
[[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
2323
throwUnsupportedDeviceError(device);

src/torchcodec/decoders/_core/CudaDevice.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -190,17 +190,15 @@ void initializeContextOnCuda(
190190
void convertAVFrameToFrameOutputOnCuda(
191191
const torch::Device& device,
192192
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
193-
VideoDecoder::AVFrameStream& avFrameStream,
193+
UniqueAVFrame& avFrame,
194194
VideoDecoder::FrameOutput& frameOutput,
195195
std::optional<torch::Tensor> preAllocatedOutputTensor) {
196-
AVFrame* avFrame = avFrameStream.avFrame.get();
197-
198196
TORCH_CHECK(
199197
avFrame->format == AV_PIX_FMT_CUDA,
200198
"Expected format to be AV_PIX_FMT_CUDA, got " +
201199
std::string(av_get_pix_fmt_name((AVPixelFormat)avFrame->format)));
202200
auto frameDims =
203-
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, *avFrame);
201+
getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
204202
int height = frameDims.height;
205203
int width = frameDims.width;
206204
torch::Tensor& dst = frameOutput.data;

src/torchcodec/decoders/_core/DeviceInterface.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ void initializeContextOnCuda(
3232
void convertAVFrameToFrameOutputOnCuda(
3333
const torch::Device& device,
3434
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
35-
VideoDecoder::AVFrameStream& avFrameStream,
35+
UniqueAVFrame& avFrame,
3636
VideoDecoder::FrameOutput& frameOutput,
3737
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
3838

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

+3-7
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,11 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode) {
4848
return std::string(errorBuffer);
4949
}
5050

51-
int64_t getDuration(const UniqueAVFrame& frame) {
52-
return getDuration(frame.get());
53-
}
54-
55-
int64_t getDuration(const AVFrame* frame) {
51+
int64_t getDuration(const UniqueAVFrame& avFrame) {
5652
#if LIBAVUTIL_VERSION_MAJOR < 58
57-
return frame->pkt_duration;
53+
return avFrame->pkt_duration;
5854
#else
59-
return frame->duration;
55+
return avFrame->duration;
6056
#endif
6157
}
6258

src/torchcodec/decoders/_core/FFMPEGCommon.h

-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
140140
// struct member representing duration has changed across the versions we
141141
// support.
142142
int64_t getDuration(const UniqueAVFrame& frame);
143-
int64_t getDuration(const AVFrame* frame);
144143

145144
int getNumChannels(const UniqueAVFrame& avFrame);
146145
int getNumChannels(const UniqueAVCodecContext& avCodecContext);

src/torchcodec/decoders/_core/VideoDecoder.cpp

+36-37
Original file line numberDiff line numberDiff line change
@@ -583,9 +583,9 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
583583
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
584584
std::optional<torch::Tensor> preAllocatedOutputTensor) {
585585
validateActiveStream();
586-
AVFrameStream avFrameStream = decodeAVFrame(
587-
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
588-
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
586+
UniqueAVFrame avFrame = decodeAVFrame(
587+
[this](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
588+
return convertAVFrameToFrameOutput(avFrame, preAllocatedOutputTensor);
589589
}
590590

591591
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) {
@@ -715,8 +715,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
715715
}
716716

717717
setCursorPtsInSeconds(seconds);
718-
AVFrameStream avFrameStream =
719-
decodeAVFrame([seconds, this](AVFrame* avFrame) {
718+
UniqueAVFrame avFrame =
719+
decodeAVFrame([seconds, this](const UniqueAVFrame& avFrame) {
720720
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
721721
double frameStartTime = ptsToSeconds(avFrame->pts, streamInfo.timeBase);
722722
double frameEndTime = ptsToSeconds(
@@ -735,7 +735,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
735735
});
736736

737737
// Convert the frame to tensor.
738-
FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameStream);
738+
FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrame);
739739
frameOutput.data = maybePermuteHWC2CHW(frameOutput.data);
740740
return frameOutput;
741741
}
@@ -891,14 +891,15 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
891891
auto finished = false;
892892
while (!finished) {
893893
try {
894-
AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) {
895-
return startPts < avFrame->pts + getDuration(avFrame);
896-
});
894+
UniqueAVFrame avFrame =
895+
decodeAVFrame([startPts](const UniqueAVFrame& avFrame) {
896+
return startPts < avFrame->pts + getDuration(avFrame);
897+
});
897898
// TODO: it's not great that we are getting a FrameOutput, which is
898899
// intended for videos. We should consider bypassing
899900
// convertAVFrameToFrameOutput and directly call
900901
// convertAudioAVFrameToFrameOutputOnCPU.
901-
auto frameOutput = convertAVFrameToFrameOutput(avFrameStream);
902+
auto frameOutput = convertAVFrameToFrameOutput(avFrame);
902903
firstFramePtsSeconds =
903904
std::min(firstFramePtsSeconds, frameOutput.ptsSeconds);
904905
frames.push_back(frameOutput.data);
@@ -1035,8 +1036,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
10351036
// LOW-LEVEL DECODING
10361037
// --------------------------------------------------------------------------
10371038

1038-
VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
1039-
std::function<bool(AVFrame*)> filterFunction) {
1039+
UniqueAVFrame VideoDecoder::decodeAVFrame(
1040+
std::function<bool(const UniqueAVFrame&)> filterFunction) {
10401041
validateActiveStream();
10411042

10421043
resetDecodeStats();
@@ -1064,7 +1065,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
10641065

10651066
decodeStats_.numFramesReceivedByDecoder++;
10661067
// Is this the kind of frame we're looking for?
1067-
if (status == AVSUCCESS && filterFunction(avFrame.get())) {
1068+
if (status == AVSUCCESS && filterFunction(avFrame)) {
10681069
// Yes, this is the frame we'll return; break out of the decoding loop.
10691070
break;
10701071
} else if (status == AVSUCCESS) {
@@ -1150,37 +1151,36 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
11501151
streamInfo.lastDecodedAvFramePts = avFrame->pts;
11511152
streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame);
11521153

1153-
return AVFrameStream(std::move(avFrame), activeStreamIndex_);
1154+
return avFrame;
11541155
}
11551156

11561157
// --------------------------------------------------------------------------
11571158
// AVFRAME <-> FRAME OUTPUT CONVERSION
11581159
// --------------------------------------------------------------------------
11591160

11601161
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
1161-
VideoDecoder::AVFrameStream& avFrameStream,
1162+
UniqueAVFrame& avFrame,
11621163
std::optional<torch::Tensor> preAllocatedOutputTensor) {
11631164
// Convert the frame to tensor.
11641165
FrameOutput frameOutput;
1165-
int streamIndex = avFrameStream.streamIndex;
1166-
AVFrame* avFrame = avFrameStream.avFrame.get();
1167-
frameOutput.streamIndex = streamIndex;
1168-
auto& streamInfo = streamInfos_[streamIndex];
1166+
frameOutput.streamIndex = activeStreamIndex_;
1167+
auto& streamInfo = streamInfos_[activeStreamIndex_];
11691168
frameOutput.ptsSeconds = ptsToSeconds(
1170-
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
1169+
avFrame->pts, formatContext_->streams[activeStreamIndex_]->time_base);
11711170
frameOutput.durationSeconds = ptsToSeconds(
1172-
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
1171+
getDuration(avFrame),
1172+
formatContext_->streams[activeStreamIndex_]->time_base);
11731173
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
11741174
convertAudioAVFrameToFrameOutputOnCPU(
1175-
avFrameStream, frameOutput, preAllocatedOutputTensor);
1175+
avFrame, frameOutput, preAllocatedOutputTensor);
11761176
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
11771177
convertAVFrameToFrameOutputOnCPU(
1178-
avFrameStream, frameOutput, preAllocatedOutputTensor);
1178+
avFrame, frameOutput, preAllocatedOutputTensor);
11791179
} else if (streamInfo.videoStreamOptions.device.type() == torch::kCUDA) {
11801180
convertAVFrameToFrameOutputOnCuda(
11811181
streamInfo.videoStreamOptions.device,
11821182
streamInfo.videoStreamOptions,
1183-
avFrameStream,
1183+
avFrame,
11841184
frameOutput,
11851185
preAllocatedOutputTensor);
11861186
} else {
@@ -1201,14 +1201,13 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
12011201
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
12021202
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
12031203
void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
1204-
VideoDecoder::AVFrameStream& avFrameStream,
1204+
UniqueAVFrame& avFrame,
12051205
FrameOutput& frameOutput,
12061206
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1207-
AVFrame* avFrame = avFrameStream.avFrame.get();
12081207
auto& streamInfo = streamInfos_[activeStreamIndex_];
12091208

12101209
auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
1211-
streamInfo.videoStreamOptions, *avFrame);
1210+
streamInfo.videoStreamOptions, avFrame);
12121211
int expectedOutputHeight = frameDims.height;
12131212
int expectedOutputWidth = frameDims.width;
12141213

@@ -1302,7 +1301,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
13021301
}
13031302

13041303
int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
1305-
const AVFrame* avFrame,
1304+
const UniqueAVFrame& avFrame,
13061305
torch::Tensor& outputTensor) {
13071306
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
13081307
SwsContext* swsContext = activeStreamInfo.swsContext.get();
@@ -1322,11 +1321,11 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
13221321
}
13231322

13241323
torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
1325-
const AVFrame* avFrame) {
1324+
const UniqueAVFrame& avFrame) {
13261325
FilterGraphContext& filterGraphContext =
13271326
streamInfos_[activeStreamIndex_].filterGraphContext;
13281327
int status =
1329-
av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame);
1328+
av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame.get());
13301329
if (status < AVSUCCESS) {
13311330
throw std::runtime_error("Failed to add frame to buffer source context");
13321331
}
@@ -1350,25 +1349,25 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
13501349
}
13511350

13521351
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU(
1353-
VideoDecoder::AVFrameStream& avFrameStream,
1352+
UniqueAVFrame& srcAVFrame,
13541353
FrameOutput& frameOutput,
13551354
std::optional<torch::Tensor> preAllocatedOutputTensor) {
13561355
TORCH_CHECK(
13571356
!preAllocatedOutputTensor.has_value(),
13581357
"pre-allocated audio tensor not supported yet.");
13591358

13601359
AVSampleFormat sourceSampleFormat =
1361-
static_cast<AVSampleFormat>(avFrameStream.avFrame->format);
1360+
static_cast<AVSampleFormat>(srcAVFrame->format);
13621361
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
13631362

13641363
UniqueAVFrame convertedAVFrame;
13651364
if (sourceSampleFormat != desiredSampleFormat) {
13661365
convertedAVFrame = convertAudioAVFrameSampleFormat(
1367-
avFrameStream.avFrame, sourceSampleFormat, desiredSampleFormat);
1366+
srcAVFrame, sourceSampleFormat, desiredSampleFormat);
13681367
}
13691368
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
13701369
? convertedAVFrame
1371-
: avFrameStream.avFrame;
1370+
: srcAVFrame;
13721371

13731372
AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
13741373
TORCH_CHECK(
@@ -1944,10 +1943,10 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
19441943

19451944
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
19461945
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
1947-
const AVFrame& avFrame) {
1946+
const UniqueAVFrame& avFrame) {
19481947
return FrameDims(
1949-
videoStreamOptions.height.value_or(avFrame.height),
1950-
videoStreamOptions.width.value_or(avFrame.width));
1948+
videoStreamOptions.height.value_or(avFrame->height),
1949+
videoStreamOptions.width.value_or(avFrame->width));
19511950
}
19521951

19531952
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/VideoDecoder.h

+9-24
Original file line numberDiff line numberDiff line change
@@ -244,23 +244,6 @@ class VideoDecoder {
244244
// These are APIs that should be private, but that are effectively exposed for
245245
// practical reasons, typically for testing purposes.
246246

247-
// This struct is needed because AVFrame doesn't retain the streamIndex. Only
248-
// the AVPacket knows its stream. This is what the low-level private decoding
249-
// entry points return. The AVFrameStream is then converted to a FrameOutput
250-
// with convertAVFrameToFrameOutput. It should be private, but is currently
251-
// used by DeviceInterface.
252-
struct AVFrameStream {
253-
// The actual decoded output as a unique pointer to an AVFrame.
254-
// Usually, this is a YUV frame. It'll be converted to RGB in
255-
// convertAVFrameToFrameOutput.
256-
UniqueAVFrame avFrame;
257-
// The stream index of the decoded frame.
258-
int streamIndex;
259-
260-
explicit AVFrameStream(UniqueAVFrame&& a, int s)
261-
: avFrame(std::move(a)), streamIndex(s) {}
262-
};
263-
264247
// Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we
265248
// can move it back to private.
266249
FrameOutput getFrameAtIndexInternal(
@@ -376,31 +359,33 @@ class VideoDecoder {
376359

377360
void maybeSeekToBeforeDesiredPts();
378361

379-
AVFrameStream decodeAVFrame(std::function<bool(AVFrame*)> filterFunction);
362+
UniqueAVFrame decodeAVFrame(
363+
std::function<bool(const UniqueAVFrame&)> filterFunction);
380364

381365
FrameOutput getNextFrameInternal(
382366
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
383367

384368
torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor);
385369

386370
FrameOutput convertAVFrameToFrameOutput(
387-
AVFrameStream& avFrameStream,
371+
UniqueAVFrame& avFrame,
388372
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
389373

390374
void convertAVFrameToFrameOutputOnCPU(
391-
AVFrameStream& avFrameStream,
375+
UniqueAVFrame& avFrame,
392376
FrameOutput& frameOutput,
393377
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
394378

395379
void convertAudioAVFrameToFrameOutputOnCPU(
396-
AVFrameStream& avFrameStream,
380+
UniqueAVFrame& srcAVFrame,
397381
FrameOutput& frameOutput,
398382
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
399383

400-
torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame);
384+
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
385+
const UniqueAVFrame& avFrame);
401386

402387
int convertAVFrameToTensorUsingSwsScale(
403-
const AVFrame* avFrame,
388+
const UniqueAVFrame& avFrame,
404389
torch::Tensor& outputTensor);
405390

406391
UniqueAVFrame convertAudioAVFrameSampleFormat(
@@ -568,7 +553,7 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
568553

569554
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
570555
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
571-
const AVFrame& avFrame);
556+
const UniqueAVFrame& avFrame);
572557

573558
torch::Tensor allocateEmptyHWCTensor(
574559
int height,

0 commit comments

Comments
 (0)