@@ -583,9 +583,9 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
583
583
VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal (
584
584
std::optional<torch::Tensor> preAllocatedOutputTensor) {
585
585
validateActiveStream ();
586
- AVFrameStream avFrameStream = decodeAVFrame (
587
- [this ](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
588
- return convertAVFrameToFrameOutput (avFrameStream , preAllocatedOutputTensor);
586
+ UniqueAVFrame avFrame = decodeAVFrame (
587
+ [this ](const UniqueAVFrame& avFrame) { return avFrame->pts >= cursor_; });
588
+ return convertAVFrameToFrameOutput (avFrame , preAllocatedOutputTensor);
589
589
}
590
590
591
591
VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex (int64_t frameIndex) {
@@ -715,8 +715,8 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
715
715
}
716
716
717
717
setCursorPtsInSeconds (seconds);
718
- AVFrameStream avFrameStream =
719
- decodeAVFrame ([seconds, this ](AVFrame* avFrame) {
718
+ UniqueAVFrame avFrame =
719
+ decodeAVFrame ([seconds, this ](const UniqueAVFrame& avFrame) {
720
720
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
721
721
double frameStartTime = ptsToSeconds (avFrame->pts , streamInfo.timeBase );
722
722
double frameEndTime = ptsToSeconds (
@@ -735,7 +735,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) {
735
735
});
736
736
737
737
// Convert the frame to tensor.
738
- FrameOutput frameOutput = convertAVFrameToFrameOutput (avFrameStream );
738
+ FrameOutput frameOutput = convertAVFrameToFrameOutput (avFrame );
739
739
frameOutput.data = maybePermuteHWC2CHW (frameOutput.data );
740
740
return frameOutput;
741
741
}
@@ -891,14 +891,15 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
891
891
auto finished = false ;
892
892
while (!finished) {
893
893
try {
894
- AVFrameStream avFrameStream = decodeAVFrame ([startPts](AVFrame* avFrame) {
895
- return startPts < avFrame->pts + getDuration (avFrame);
896
- });
894
+ UniqueAVFrame avFrame =
895
+ decodeAVFrame ([startPts](const UniqueAVFrame& avFrame) {
896
+ return startPts < avFrame->pts + getDuration (avFrame);
897
+ });
897
898
// TODO: it's not great that we are getting a FrameOutput, which is
898
899
// intended for videos. We should consider bypassing
899
900
// convertAVFrameToFrameOutput and directly call
900
901
// convertAudioAVFrameToFrameOutputOnCPU.
901
- auto frameOutput = convertAVFrameToFrameOutput (avFrameStream );
902
+ auto frameOutput = convertAVFrameToFrameOutput (avFrame );
902
903
firstFramePtsSeconds =
903
904
std::min (firstFramePtsSeconds, frameOutput.ptsSeconds );
904
905
frames.push_back (frameOutput.data );
@@ -1035,8 +1036,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
1035
1036
// LOW-LEVEL DECODING
1036
1037
// --------------------------------------------------------------------------
1037
1038
1038
- VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame (
1039
- std::function<bool (AVFrame* )> filterFunction) {
1039
+ UniqueAVFrame VideoDecoder::decodeAVFrame (
1040
+ std::function<bool (const UniqueAVFrame& )> filterFunction) {
1040
1041
validateActiveStream ();
1041
1042
1042
1043
resetDecodeStats ();
@@ -1064,7 +1065,7 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
1064
1065
1065
1066
decodeStats_.numFramesReceivedByDecoder ++;
1066
1067
// Is this the kind of frame we're looking for?
1067
- if (status == AVSUCCESS && filterFunction (avFrame. get () )) {
1068
+ if (status == AVSUCCESS && filterFunction (avFrame)) {
1068
1069
// Yes, this is the frame we'll return; break out of the decoding loop.
1069
1070
break ;
1070
1071
} else if (status == AVSUCCESS) {
@@ -1150,37 +1151,36 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
1150
1151
streamInfo.lastDecodedAvFramePts = avFrame->pts ;
1151
1152
streamInfo.lastDecodedAvFrameDuration = getDuration (avFrame);
1152
1153
1153
- return AVFrameStream ( std::move ( avFrame), activeStreamIndex_) ;
1154
+ return avFrame;
1154
1155
}
1155
1156
1156
1157
// --------------------------------------------------------------------------
1157
1158
// AVFRAME <-> FRAME OUTPUT CONVERSION
1158
1159
// --------------------------------------------------------------------------
1159
1160
1160
1161
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput (
1161
- VideoDecoder::AVFrameStream& avFrameStream ,
1162
+ UniqueAVFrame& avFrame ,
1162
1163
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1163
1164
// Convert the frame to tensor.
1164
1165
FrameOutput frameOutput;
1165
- int streamIndex = avFrameStream.streamIndex ;
1166
- AVFrame* avFrame = avFrameStream.avFrame .get ();
1167
- frameOutput.streamIndex = streamIndex;
1168
- auto & streamInfo = streamInfos_[streamIndex];
1166
+ frameOutput.streamIndex = activeStreamIndex_;
1167
+ auto & streamInfo = streamInfos_[activeStreamIndex_];
1169
1168
frameOutput.ptsSeconds = ptsToSeconds (
1170
- avFrame->pts , formatContext_->streams [streamIndex ]->time_base );
1169
+ avFrame->pts , formatContext_->streams [activeStreamIndex_ ]->time_base );
1171
1170
frameOutput.durationSeconds = ptsToSeconds (
1172
- getDuration (avFrame), formatContext_->streams [streamIndex]->time_base );
1171
+ getDuration (avFrame),
1172
+ formatContext_->streams [activeStreamIndex_]->time_base );
1173
1173
if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) {
1174
1174
convertAudioAVFrameToFrameOutputOnCPU (
1175
- avFrameStream , frameOutput, preAllocatedOutputTensor);
1175
+ avFrame , frameOutput, preAllocatedOutputTensor);
1176
1176
} else if (streamInfo.videoStreamOptions .device .type () == torch::kCPU ) {
1177
1177
convertAVFrameToFrameOutputOnCPU (
1178
- avFrameStream , frameOutput, preAllocatedOutputTensor);
1178
+ avFrame , frameOutput, preAllocatedOutputTensor);
1179
1179
} else if (streamInfo.videoStreamOptions .device .type () == torch::kCUDA ) {
1180
1180
convertAVFrameToFrameOutputOnCuda (
1181
1181
streamInfo.videoStreamOptions .device ,
1182
1182
streamInfo.videoStreamOptions ,
1183
- avFrameStream ,
1183
+ avFrame ,
1184
1184
frameOutput,
1185
1185
preAllocatedOutputTensor);
1186
1186
} else {
@@ -1201,14 +1201,13 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
1201
1201
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
1202
1202
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
1203
1203
void VideoDecoder::convertAVFrameToFrameOutputOnCPU (
1204
- VideoDecoder::AVFrameStream& avFrameStream ,
1204
+ UniqueAVFrame& avFrame ,
1205
1205
FrameOutput& frameOutput,
1206
1206
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1207
- AVFrame* avFrame = avFrameStream.avFrame .get ();
1208
1207
auto & streamInfo = streamInfos_[activeStreamIndex_];
1209
1208
1210
1209
auto frameDims = getHeightAndWidthFromOptionsOrAVFrame (
1211
- streamInfo.videoStreamOptions , * avFrame);
1210
+ streamInfo.videoStreamOptions , avFrame);
1212
1211
int expectedOutputHeight = frameDims.height ;
1213
1212
int expectedOutputWidth = frameDims.width ;
1214
1213
@@ -1302,7 +1301,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
1302
1301
}
1303
1302
1304
1303
int VideoDecoder::convertAVFrameToTensorUsingSwsScale (
1305
- const AVFrame* avFrame,
1304
+ const UniqueAVFrame& avFrame,
1306
1305
torch::Tensor& outputTensor) {
1307
1306
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
1308
1307
SwsContext* swsContext = activeStreamInfo.swsContext .get ();
@@ -1322,11 +1321,11 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
1322
1321
}
1323
1322
1324
1323
torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph (
1325
- const AVFrame* avFrame) {
1324
+ const UniqueAVFrame& avFrame) {
1326
1325
FilterGraphContext& filterGraphContext =
1327
1326
streamInfos_[activeStreamIndex_].filterGraphContext ;
1328
1327
int status =
1329
- av_buffersrc_write_frame (filterGraphContext.sourceContext , avFrame);
1328
+ av_buffersrc_write_frame (filterGraphContext.sourceContext , avFrame. get () );
1330
1329
if (status < AVSUCCESS) {
1331
1330
throw std::runtime_error (" Failed to add frame to buffer source context" );
1332
1331
}
@@ -1350,25 +1349,25 @@ torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
1350
1349
}
1351
1350
1352
1351
void VideoDecoder::convertAudioAVFrameToFrameOutputOnCPU (
1353
- VideoDecoder::AVFrameStream& avFrameStream ,
1352
+ UniqueAVFrame& srcAVFrame ,
1354
1353
FrameOutput& frameOutput,
1355
1354
std::optional<torch::Tensor> preAllocatedOutputTensor) {
1356
1355
TORCH_CHECK (
1357
1356
!preAllocatedOutputTensor.has_value (),
1358
1357
" pre-allocated audio tensor not supported yet." );
1359
1358
1360
1359
AVSampleFormat sourceSampleFormat =
1361
- static_cast <AVSampleFormat>(avFrameStream. avFrame ->format );
1360
+ static_cast <AVSampleFormat>(srcAVFrame ->format );
1362
1361
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
1363
1362
1364
1363
UniqueAVFrame convertedAVFrame;
1365
1364
if (sourceSampleFormat != desiredSampleFormat) {
1366
1365
convertedAVFrame = convertAudioAVFrameSampleFormat (
1367
- avFrameStream. avFrame , sourceSampleFormat, desiredSampleFormat);
1366
+ srcAVFrame , sourceSampleFormat, desiredSampleFormat);
1368
1367
}
1369
1368
const UniqueAVFrame& avFrame = (sourceSampleFormat != desiredSampleFormat)
1370
1369
? convertedAVFrame
1371
- : avFrameStream. avFrame ;
1370
+ : srcAVFrame ;
1372
1371
1373
1372
AVSampleFormat format = static_cast <AVSampleFormat>(avFrame->format );
1374
1373
TORCH_CHECK (
@@ -1944,10 +1943,10 @@ FrameDims getHeightAndWidthFromOptionsOrMetadata(
1944
1943
1945
1944
FrameDims getHeightAndWidthFromOptionsOrAVFrame (
1946
1945
const VideoDecoder::VideoStreamOptions& videoStreamOptions,
1947
- const AVFrame & avFrame) {
1946
+ const UniqueAVFrame & avFrame) {
1948
1947
return FrameDims (
1949
- videoStreamOptions.height .value_or (avFrame. height ),
1950
- videoStreamOptions.width .value_or (avFrame. width ));
1948
+ videoStreamOptions.height .value_or (avFrame-> height ),
1949
+ videoStreamOptions.width .value_or (avFrame-> width ));
1951
1950
}
1952
1951
1953
1952
} // namespace facebook::torchcodec
0 commit comments