feat: support audio modal input & refactor media decoder. #682
base: main
Conversation
Code Review
This pull request introduces support for audio input and refactors the media decoding logic by centralizing FFmpeg-based decoding for video and audio into a MemoryMediaReaderBase class and its derivatives. This is a significant improvement for code structure and maintainability. The changes also include adding an AudioHandler and updating related data structures to support audio. My review has identified a potential bug in the audio decoding logic that could cause decoding to fail for valid audio files, and a significant naming inconsistency that could be misleading for future maintenance. Overall, these are positive changes that enhance the project's capabilities.
```cpp
bool OpenCVVideoDecoder::decode(const std::string& raw_data,
                                torch::Tensor& t,
                                VideoMetadata& metadata) {
```
The class name OpenCVVideoDecoder is misleading. The implementation, both before and after this refactoring, uses FFmpeg for video decoding, not OpenCV. This is confusing, especially since there is an OpenCVImageDecoder that correctly uses OpenCV. For better code clarity and maintainability, consider renaming OpenCVVideoDecoder to FFmpegVideoDecoder across the codebase.
Force-pushed from `43976d1` to `0bbd498`
```cpp
};

struct Reader {
  static int read(void* opaque, uint8_t* buf, int buf_size) {
```
Specify all `int` types explicitly as `int32_t` or `int64_t`.
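One wrinkle with this suggestion: FFmpeg's AVIO read callback must keep its `int (*)(void*, uint8_t*, int)` signature, so only the internal arithmetic can move to fixed-width types. Below is a minimal, self-contained sketch of that approach; `MemCtx` mirrors the struct in the diff, and `-1` stands in for `AVERROR_EOF` so the sketch needs no FFmpeg headers.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstring>

// Mirrors the in-memory context from the diff, with explicit widths.
struct MemCtx {
  const uint8_t* p;
  uint64_t sz;
  uint64_t off;
};

// The callback keeps FFmpeg's mandated `int` parameters; all internal
// size math is done in uint64_t before narrowing the result back.
static int read_cb(void* opaque, uint8_t* buf, int buf_size) {
  if (buf_size <= 0) return -1;  // -1 stands in for AVERROR_EOF here
  auto* mc = static_cast<MemCtx*>(opaque);
  const uint64_t remain = mc->sz - mc->off;
  // Clamp to the request size, then narrow; safe because the result
  // is bounded by buf_size, which fits in int32_t.
  const int32_t n = static_cast<int32_t>(
      std::min<uint64_t>(remain, static_cast<uint64_t>(buf_size)));
  if (n <= 0) return -1;
  std::memcpy(buf, mc->p + mc->off, static_cast<size_t>(n));
  mc->off += static_cast<uint64_t>(n);
  return n;
}
```

The same pattern applies to the seek callback: keep `int64_t offset, int whence` as FFmpeg requires and confine the fixed-width cleanup to locals.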
```cpp
if (pkt_) av_packet_free(&pkt_);
if (codec_ctx_) avcodec_free_context(&codec_ctx_);
if (fmt_ctx_) {
  if (opened_)
```
Add braces `{}` around the `if` body.
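The risk with a brace-less `if` is that a statement appended later silently falls outside the condition (the classic "goto fail" class of bug). A small self-contained sketch of the braced cleanup shape; the `Ctx` type and method names here are illustrative, not from the PR:

```cpp
// Illustrative stand-in for the decoder's cleanup state.
struct Ctx {
  bool opened = false;
  int closes = 0;
  void close_input() { ++closes; }   // models avformat_close_input
  void free_context() { /* models avformat_free_context */ }
};

void release(Ctx& c) {
  if (c.opened) {
    c.close_input();  // braces make it safe to add logging etc. later
  }
  c.free_context();   // always runs, clearly outside the condition
}
```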
```cpp
  return true;
}

bool decode_all() {
```
This function is poorly structured; please refactor it.
I think you should use AI to refactor all classes and functions in these files and add comments.
/gemini review
Code Review
This pull request introduces support for audio modal input and refactors the media decoding logic. While this refactoring centralizes media processing, it also introduces significant security risks, including a critical Remote Code Execution vulnerability via the FFmpeg library when parsing untrusted media files, and a high-risk Server-Side Request Forgery (SSRF) vulnerability when fetching media from user-provided URLs. Additionally, critical issues related to resource management, logical correctness in handlers, and processing of multi-modal inputs need to be addressed.
```cpp
bool OpenCVVideoDecoder::decode(const std::string& raw_data,
                                torch::Tensor& t,
                                VideoMetadata& metadata) {
  struct MemCtx {
    const uint8_t* p;
    size_t sz;
    size_t off;
  };

  struct Reader {
    static int read(void* opaque, uint8_t* buf, int buf_size) {
      auto* mc = static_cast<MemCtx*>(opaque);
      size_t remain = mc->sz - mc->off;
      int n = (int)std::min(remain, (size_t)buf_size);
      if (n <= 0) return AVERROR_EOF;
      memcpy(buf, mc->p + mc->off, n);
      mc->off += (size_t)n;
      return n;
    }

    static int64_t seek(void* opaque, int64_t offset, int whence) {
      auto* mc = static_cast<MemCtx*>(opaque);

      if (whence == AVSEEK_SIZE) {
        return (int64_t)mc->sz;
      }

      int64_t pos = 0;
      switch (whence) {
        case SEEK_SET:
          pos = offset;
          break;
        case SEEK_CUR:
          pos = (int64_t)mc->off + offset;
          break;
        case SEEK_END:
          pos = (int64_t)mc->sz + offset;
          break;
        default:
          return AVERROR(EINVAL);
      }

      if (pos < 0 || pos > (int64_t)mc->sz) return AVERROR_EOF;

      mc->off = (size_t)pos;
      return pos;
    }
  };

  AVFormatContext* fmt_ctx = avformat_alloc_context();
  const int avio_buf_sz = 1 << 16;
  uint8_t* avio_buf = (uint8_t*)av_malloc(avio_buf_sz);
  if (!fmt_ctx || !avio_buf) {
    if (fmt_ctx) avformat_free_context(fmt_ctx);
    if (avio_buf) av_free(avio_buf);
    return false;
  }

  MemCtx mc{(const uint8_t*)raw_data.data(), raw_data.size(), 0};

  AVIOContext* avio_ctx = avio_alloc_context(
      avio_buf, avio_buf_sz, 0, &mc, &Reader::read, nullptr, &Reader::seek);
  if (!avio_ctx) {
    av_free(avio_buf);
    avformat_free_context(fmt_ctx);
    return false;
  }
  MemoryVideoReader reader(reinterpret_cast<const uint8_t*>(raw_data.data()),
                           raw_data.size());

  avio_ctx->seekable = AVIO_SEEKABLE_NORMAL;

  fmt_ctx->pb = avio_ctx;
  fmt_ctx->flags |= AVFMT_FLAG_CUSTOM_IO;
  fmt_ctx->probesize = std::min<size_t>(raw_data.size(), 20 * 1024 * 1024);
  fmt_ctx->max_analyze_duration = 5LL * AV_TIME_BASE;

  bool ok = false;

  if (avformat_open_input(&fmt_ctx, nullptr, nullptr, nullptr) < 0) {
    av_freep(&avio_ctx->buffer);
    avio_context_free(&avio_ctx);
    avformat_free_context(fmt_ctx);
    return false;
  }

  if (avformat_find_stream_info(fmt_ctx, nullptr) < 0) {
    av_freep(&avio_ctx->buffer);
    avio_context_free(&avio_ctx);
    avformat_close_input(&fmt_ctx);
    return false;
  }

  int vs = av_find_best_stream(fmt_ctx, AVMEDIA_TYPE_VIDEO, -1, -1, nullptr, 0);
  if (vs < 0) {
    av_freep(&avio_ctx->buffer);
    avio_context_free(&avio_ctx);
    avformat_close_input(&fmt_ctx);
    return false;
  }

  AVStream* st = fmt_ctx->streams[vs];
  AVCodecParameters* par = st->codecpar;
  const AVCodec* dec = avcodec_find_decoder(par->codec_id);
  if (!dec) {
    av_freep(&avio_ctx->buffer);
    avio_context_free(&avio_ctx);
    avformat_close_input(&fmt_ctx);
    return false;
  }

  AVCodecContext* codec_ctx = avcodec_alloc_context3(dec);
  if (!codec_ctx) {
    av_freep(&avio_ctx->buffer);
    avio_context_free(&avio_ctx);
    avformat_close_input(&fmt_ctx);
    return false;
  }

  if (avcodec_parameters_to_context(codec_ctx, par) < 0 ||
      avcodec_open2(codec_ctx, dec, nullptr) < 0) {
    avcodec_free_context(&codec_ctx);
    av_freep(&avio_ctx->buffer);
    avio_context_free(&avio_ctx);
    avformat_close_input(&fmt_ctx);
    return false;
  }

  AVRational r = st->avg_frame_rate.num ? st->avg_frame_rate : st->r_frame_rate;
  double fps = (r.num && r.den) ? av_q2d(r) : 0.0;
  metadata.fps = fps;

  SwsContext* sws = nullptr;
  AVPacket* pkt = av_packet_alloc();
  AVFrame* frm = av_frame_alloc();
  std::vector<torch::Tensor> frames;

  auto push_frame = [&](AVFrame* f) -> bool {
    if (!sws) {
      sws = sws_getContext(f->width,
                           f->height,
                           (AVPixelFormat)f->format,
                           f->width,
                           f->height,
                           AV_PIX_FMT_RGB24,
                           SWS_BILINEAR,
                           nullptr,
                           nullptr,
                           nullptr);
      if (!sws) return false;
    }

    torch::Tensor rgb = torch::empty({f->height, f->width, 3}, torch::kUInt8);
    uint8_t* dst_data[4] = {rgb.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
    int dst_linesize[4] = {(int)rgb.stride(0), 0, 0, 0};

    sws_scale(sws, f->data, f->linesize, 0, f->height, dst_data, dst_linesize);

    frames.emplace_back(rgb.permute({2, 0, 1}).clone());  // [C,H,W]
    return true;
  };

  while (av_read_frame(fmt_ctx, pkt) >= 0) {
    if (pkt->stream_index == vs) {
      if (avcodec_send_packet(codec_ctx, pkt) == 0) {
        while (avcodec_receive_frame(codec_ctx, frm) == 0) {
          if (!push_frame(frm)) break;
        }
      }
    }
    av_packet_unref(pkt);
  }

  // flush
  avcodec_send_packet(codec_ctx, nullptr);
  while (avcodec_receive_frame(codec_ctx, frm) == 0) {
    if (!push_frame(frm)) break;
  }

  if (!frames.empty()) {
    t = torch::stack(frames);  // [T,C,H,W]
    metadata.total_num_frames = static_cast<int64_t>(frames.size());
    if (metadata.fps > 0.0) {
      metadata.duration = metadata.total_num_frames / metadata.fps;
    } else {
      metadata.duration = 0.0;
    }
    ok = true;
  }

  if (sws) sws_freeContext(sws);
  av_frame_free(&frm);
  av_packet_free(&pkt);
  avcodec_free_context(&codec_ctx);
  if (!reader.init(metadata)) return false;
  if (!reader.read_all(t, metadata)) return false;
  return true;
}

  av_freep(&avio_ctx->buffer);
  avio_context_free(&avio_ctx);
  avformat_close_input(&fmt_ctx);
bool FFmpegAudioDecoder::decode(const std::string& raw_data,
                                torch::Tensor& t,
                                AudioMetadata& metadata) {
  MemoryAudioReader reader(reinterpret_cast<const uint8_t*>(raw_data.data()),
                           raw_data.size());

  return ok;
  if (!reader.init(metadata)) return false;
  if (!reader.read_all(t, metadata)) return false;
  return true;
}
```
The functions OpenCVVideoDecoder::decode and FFmpegAudioDecoder::decode process raw media data (raw_data) using the FFmpeg library. This data is sourced from user-provided URLs or direct uploads and must be considered untrusted. The FFmpeg library has a history of critical memory corruption vulnerabilities (such as buffer overflows, integer overflows, and use-after-frees) that can be triggered by parsing specially crafted media files. An attacker could upload a malicious video or audio file, which, when processed by FFmpeg on the server, could lead to Remote Code Execution (RCE). The C++ wrapper code, while performing some checks, cannot prevent the exploitation of vulnerabilities within the FFmpeg library itself.
Remediation:
- Sandboxing: The most effective mitigation is to run the media decoding process in a tightly sandboxed, isolated environment (e.g., a separate container with minimal privileges, or using gVisor/seccomp).
- Library Versioning: Ensure that the FFmpeg library is kept up-to-date with the latest security patches. Implement a process for monitoring FFmpeg vulnerabilities and applying updates promptly.
```cpp
if (input_item.has_type(MMType::IMAGE)) {
  if (input_item.decode_image_.defined()) {
    images.push_back(input_item.decode_image_);
  } else if (input_item.embedding_.embedding.defined()) {
    images_embedding.push_back(input_item.embedding_);
  }
} else if (input_item.type_ == MMType::VIDEO) {
  if (input_item.decode_data_.defined()) {
    videos.push_back(input_item.decode_data_);
  } else if (input_item.has_type(MMType::VIDEO)) {
    if (input_item.decode_video_.defined()) {
      videos.push_back(input_item.decode_video_);
    }
    video_meta_list.push_back(input_item.video_meta_);
  } else if (input_item.has_type(MMType::AUDIO)) {
    if (input_item.decode_audio_.defined()) {
      audios.push_back(input_item.decode_audio_);
    }
    audio_meta_list.push_back(input_item.audio_meta_);
  }
```
The if-else if structure for checking the media type of an MMInputItem is incorrect for handling items that contain multiple modalities, such as a video with an audio track. Since VideoHandler can produce an item with type_ being MMType::VIDEO | MMType::AUDIO, this logic will only process the video part and silently ignore the audio part. This will lead to data loss. You should use separate if statements for each modality to handle combined types correctly.
```cpp
if (input_item.has_type(MMType::IMAGE)) {
  if (input_item.decode_image_.defined()) {
    images.push_back(input_item.decode_image_);
  } else if (input_item.embedding_.embedding.defined()) {
    images_embedding.push_back(input_item.embedding_);
  }
}
if (input_item.has_type(MMType::VIDEO)) {
  if (input_item.decode_video_.defined()) {
    videos.push_back(input_item.decode_video_);
  }
  video_meta_list.push_back(input_item.video_meta_);
}
if (input_item.has_type(MMType::AUDIO)) {
  if (input_item.decode_audio_.defined()) {
    audios.push_back(input_item.decode_audio_);
  }
  audio_meta_list.push_back(input_item.audio_meta_);
}
```

```cpp
bool AudioHandler::load(const MMContent& content,
                        MMInputItem& input,
                        MMPayload& payload) {
  input.clear();

  const auto& audio_url = content.audio_url;
  const auto& url = audio_url.url;

  if (url.compare(0, dataurl_prefix_.size(), dataurl_prefix_) ==
      0) {  // data url
    input.type_ = MMType::AUDIO;
    return this->load_from_dataurl(url, input.raw_data_, payload);
  } else if (url.compare(0, httpurl_prefix_.size(), httpurl_prefix_) ==
             0) {  // http url
    input.type_ = MMType::AUDIO;
    return this->load_from_http(url, input.raw_data_);
  } else {
    LOG(ERROR) << " audio url is invalid, url is " << url;
    return false;
  }
}
```
The AudioHandler::load function (and similarly ImageHandler::load and VideoHandler::load) retrieves data from a user-provided URL. The code checks if the URL starts with "http" before passing it to the load_from_http function, which in turn makes a request to that URL. This validation is insufficient to prevent Server-Side Request Forgery (SSRF) attacks. An attacker can provide a URL pointing to internal services or cloud metadata endpoints (e.g., http://127.0.0.1/admin, http://169.254.169.254/latest/meta-data/). This could allow an attacker to scan the internal network, access sensitive internal services, or steal cloud infrastructure credentials.
Remediation:
Implement a strict allow-list of trusted domains and IP ranges that the server is permitted to request. Requests to URLs outside of this allow-list should be blocked.
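A minimal sketch of such an allow-list gate, to run before `load_from_http` issues any request. The host list, helper names, and the naive URL parsing here are illustrative assumptions, not part of the existing codebase; a production version should use a real URL parser and also pin the resolved IP and cap redirects to defend against DNS rebinding.

```cpp
#include <string>
#include <unordered_set>

// Assumption: the set of hosts the server may fetch media from.
static const std::unordered_set<std::string> kAllowedHosts = {
    "media.example.com",  // hypothetical trusted CDN host
};

// Extracts the host part of an http(s) URL. Deliberately simple:
// no userinfo or IPv6-literal handling.
inline std::string extract_host(const std::string& url) {
  const auto scheme_end = url.find("://");
  if (scheme_end == std::string::npos) return "";
  const auto host_begin = scheme_end + 3;
  const auto host_end = url.find_first_of("/:?#", host_begin);
  return url.substr(host_begin,
                    host_end == std::string::npos ? std::string::npos
                                                  : host_end - host_begin);
}

// Exact-match check against the allow-list; rejects anything else,
// including internal IPs and metadata endpoints.
inline bool is_url_allowed(const std::string& url) {
  const std::string host = extract_host(url);
  return !host.empty() && kAllowedHosts.count(host) > 0;
}
```

Exact host matching (rather than prefix or substring matching) is important here, so that `media.example.com.evil.com` does not pass the check.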
```cpp
  av_channel_layout_uninit(&out_layout);
  av_channel_layout_uninit(&in_layout);
  return false;
}
```
There's a potential memory leak in this error handling path. swr_ctx_ is allocated with swr_alloc(), but if swr_alloc_set_opts2() fails, swr_ctx_ is not freed before returning. While the destructor would eventually clean this up, it's inconsistent with other error paths in this function that do call swr_free(). To ensure consistent and immediate resource cleanup, you should free the context here.
```cpp
av_channel_layout_uninit(&out_layout);
av_channel_layout_uninit(&in_layout);
swr_free(&swr_ctx_);
return false;
```

```cpp
bool VideoHandler::decode(MMInputItem& input) {
  FFmpegAudioDecoder audio_decoder;
  if (audio_decoder.decode(
          input.raw_data_, input.decode_audio_, input.audio_meta_)) {
    input.type_ |= MMType::AUDIO;
  }

  OpenCVVideoDecoder decoder;
  return decoder.decode(input.raw_data_, input.decode_data_, input.video_meta_);
  return decoder.decode(
      input.raw_data_, input.decode_video_, input.video_meta_);
}
```
The return value of VideoHandler::decode depends solely on the success of video decoding, ignoring the result of audio decoding. This means if a video file contains a valid audio stream but an unsupported or corrupted video stream, the function will return false even though audio was successfully extracted. This will cause the processing of the item to fail and the extracted audio data to be discarded. The function should return true if at least one of the media types (audio or video) is successfully decoded.
```cpp
bool VideoHandler::decode(MMInputItem& input) {
  bool audio_decoded = false;
  FFmpegAudioDecoder audio_decoder;
  if (audio_decoder.decode(
          input.raw_data_, input.decode_audio_, input.audio_meta_)) {
    input.type_ |= MMType::AUDIO;
    audio_decoded = true;
  }
  OpenCVVideoDecoder video_decoder;
  bool video_decoded = video_decoder.decode(
      input.raw_data_, input.decode_video_, input.video_meta_);
  return audio_decoded || video_decoded;
}
```
No description provided.