Skip to content

Commit 58093df

Browse files
authored
Filter NALs globally (#1403)
* Linux: Remove filter_NAL and avoid copy * Rework NAL filtering * Free packet before getting new one (@nowrep suggestion) * Remove unnecessary condition from `extractHeaders()` * Fix decoder init and properly offset the pointers * Use AMFBufferPtr instead of composing FramePacket * Add boundary checks
1 parent 0734d61 commit 58093df

File tree

6 files changed

+116
-131
lines changed

6 files changed

+116
-131
lines changed

alvr/server/cpp/alvr_server/ClientConnection.cpp

Lines changed: 82 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -8,65 +8,99 @@
88
#include "Utils.h"
99
#include "Settings.h"
1010

11-
static const uint8_t NAL_TYPE_SPS = 7;
11+
static const char NAL_HEADER[] = {0x00, 0x00, 0x00, 0x01};
12+
13+
static const uint8_t H264_NAL_TYPE_SPS = 7;
1214
static const uint8_t H265_NAL_TYPE_VPS = 32;
1315

14-
ClientConnection::ClientConnection() {
15-
m_Statistics = std::make_shared<Statistics>();
16+
static const uint8_t H264_NAL_TYPE_AUD = 9;
17+
static const uint8_t H265_NAL_TYPE_AUD = 35;
18+
19+
ClientConnection::ClientConnection() {
20+
m_Statistics = std::make_shared<Statistics>();
1621
}
1722

18-
int findVPSSPS(const uint8_t *frameBuffer, int frameByteSize) {
19-
int zeroes = 0;
20-
int foundNals = 0;
21-
for (int i = 0; i < frameByteSize; i++) {
22-
if (frameBuffer[i] == 0) {
23-
zeroes++;
24-
} else if (frameBuffer[i] == 1) {
25-
if (zeroes >= 2) {
26-
foundNals++;
27-
if (Settings::Instance().m_codec == ALVR_CODEC_H264 && foundNals >= 3) {
28-
// Find end of SPS+PPS on H.264.
29-
return i - 3;
30-
} else if (Settings::Instance().m_codec == ALVR_CODEC_H265 && foundNals >= 4) {
31-
// Find end of VPS+SPS+PPS on H.264.
32-
return i - 3;
33-
}
34-
}
35-
zeroes = 0;
36-
} else {
37-
zeroes = 0;
38-
}
39-
}
40-
return -1;
23+
/*
24+
Sends the (VPS + )SPS + PPS video configuration headers from H.264 or H.265 stream as a sequence of NALs.
25+
(VPS + )SPS + PPS have short size (8bytes + 28bytes in some environment), so we can
26+
assume SPS + PPS is contained in first fragment.
27+
*/
28+
void sendHeaders(uint8_t **buf, int *len, int nalNum) {
29+
uint8_t *b = *buf;
30+
uint8_t *end = b + *len;
31+
32+
int headersLen = 0;
33+
int foundHeaders = -1; // Offset by 1 header to find the length until the next header
34+
while (b != end) {
35+
if (b + sizeof(NAL_HEADER) <= end && memcmp(b, NAL_HEADER, sizeof(NAL_HEADER)) == 0) {
36+
foundHeaders++;
37+
if (foundHeaders == nalNum) {
38+
break;
39+
}
40+
b += sizeof(NAL_HEADER);
41+
headersLen += sizeof(NAL_HEADER);
42+
}
43+
44+
b++;
45+
headersLen++;
46+
}
47+
if (foundHeaders != nalNum) {
48+
return;
49+
}
50+
InitializeDecoder((const unsigned char *)*buf, headersLen);
51+
52+
// move the cursor forward excluding config NALs
53+
*buf = b;
54+
*len -= headersLen;
55+
}
56+
57+
void processH264Nals(uint8_t **buf, int *len) {
58+
uint8_t *b = *buf;
59+
int l = *len;
60+
uint8_t nalType = b[4] & 0x1F;
61+
62+
if (nalType == H264_NAL_TYPE_AUD && l > sizeof(NAL_HEADER) * 2 + 2) {
63+
b += sizeof(NAL_HEADER) + 2;
64+
l -= sizeof(NAL_HEADER) + 2;
65+
nalType = b[4] & 0x1F;
66+
}
67+
if (nalType == H264_NAL_TYPE_SPS) {
68+
sendHeaders(&b, &l, 2); // 2 headers SPS and PPS
69+
}
70+
*buf = b;
71+
*len = l;
72+
}
73+
74+
void processH265Nals(uint8_t **buf, int *len) {
75+
uint8_t *b = *buf;
76+
int l = *len;
77+
uint8_t nalType = (b[4] >> 1) & 0x3F;
78+
79+
if (nalType == H265_NAL_TYPE_AUD && l > sizeof(NAL_HEADER) * 2 + 3) {
80+
b += sizeof(NAL_HEADER) + 3;
81+
l -= sizeof(NAL_HEADER) + 3;
82+
nalType = (b[4] >> 1) & 0x3F;
83+
}
84+
if (nalType == H265_NAL_TYPE_VPS) {
85+
sendHeaders(&b, &l, 3); // 3 headers VPS, SPS and PPS
86+
}
87+
*buf = b;
88+
*len = l;
4189
}
4290

4391
void ClientConnection::SendVideo(uint8_t *buf, int len, uint64_t targetTimestampNs) {
4492
// Report before the frame is packetized
4593
ReportEncoded(targetTimestampNs);
4694

47-
uint8_t NALType;
48-
if (Settings::Instance().m_codec == ALVR_CODEC_H264)
49-
NALType = buf[4] & 0x1F;
50-
else
51-
NALType = (buf[4] >> 1) & 0x3F;
52-
53-
if ((Settings::Instance().m_codec == ALVR_CODEC_H264 && NALType == NAL_TYPE_SPS) ||
54-
(Settings::Instance().m_codec == ALVR_CODEC_H265 && NALType == H265_NAL_TYPE_VPS)) {
55-
// This frame contains (VPS + )SPS + PPS + IDR on NVENC H.264 (H.265) stream.
56-
// (VPS + )SPS + PPS has short size (8bytes + 28bytes in some environment), so we can
57-
// assume SPS + PPS is contained in first fragment.
58-
59-
int end = findVPSSPS(buf, len);
60-
if (end == -1) {
61-
// Invalid frame.
62-
return;
63-
}
64-
65-
InitializeDecoder((const unsigned char *)buf, end);
95+
if (len < sizeof(NAL_HEADER)) {
96+
return;
97+
}
6698

67-
// move the cursor forward excluding config NALs
68-
buf = &buf[end];
69-
len = len - end;
99+
int codec = Settings::Instance().m_codec;
100+
if (codec == ALVR_CODEC_H264) {
101+
processH264Nals(&buf, &len);
102+
} else if (codec == ALVR_CODEC_H265) {
103+
processH265Nals(&buf, &len);
70104
}
71105

72106
VideoSend(targetTimestampNs, buf, len);

alvr/server/cpp/platform/linux/CEncoder.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,6 @@ void CEncoder::Run() {
224224

225225
fprintf(stderr, "CEncoder starting to read present packets");
226226
present_packet frame_info;
227-
std::vector<uint8_t> encoded_data;
228227
while (not m_exiting) {
229228
read_latest(client, (char *)&frame_info, sizeof(frame_info), m_exiting);
230229

@@ -250,9 +249,8 @@ void CEncoder::Run() {
250249

251250
static_assert(sizeof(frame_info.pose) == sizeof(vr::HmdMatrix34_t&));
252251

253-
encoded_data.clear();
254-
uint64_t pts;
255-
if (!encode_pipeline->GetEncoded(encoded_data, &pts)) {
252+
alvr::FramePacket packet;
253+
if (!encode_pipeline->GetEncoded(packet)) {
256254
Error("Failed to get encoded data!");
257255
continue;
258256
}
@@ -279,10 +277,9 @@ void CEncoder::Run() {
279277
ReportPresent(pose->targetTimestampNs, present_offset);
280278
ReportComposed(pose->targetTimestampNs, composed_offset);
281279

282-
m_listener->SendVideo(encoded_data.data(), encoded_data.size(), pts);
280+
m_listener->SendVideo(packet.data, packet.size, packet.pts);
283281

284282
m_listener->GetStatistics()->EncodeOutput();
285-
286283
}
287284
}
288285
catch (std::exception &e) {

alvr/server/cpp/platform/linux/EncodePipeline.cpp

Lines changed: 12 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -12,59 +12,6 @@ extern "C" {
1212
#include <libavcodec/avcodec.h>
1313
}
1414

15-
namespace {
16-
17-
bool should_keep_nal_h264(const uint8_t * header_start)
18-
{
19-
uint8_t nal_type = (header_start[2] == 0 ? header_start[4] : header_start[3]) & 0x1F;
20-
switch (nal_type)
21-
{
22-
case 6: // supplemental enhancement information
23-
case 9: // access unit delimiter
24-
return false;
25-
default:
26-
return true;
27-
}
28-
}
29-
30-
bool should_keep_nal_h265(const uint8_t * header_start)
31-
{
32-
uint8_t nal_type = ((header_start[2] == 0 ? header_start[4] : header_start[3]) >> 1) & 0x3F;
33-
switch (nal_type)
34-
{
35-
case 35: // access unit delimiter
36-
case 39: // supplemental enhancement information
37-
return false;
38-
default:
39-
return true;
40-
}
41-
}
42-
43-
void filter_NAL(const uint8_t* input, size_t input_size, std::vector<uint8_t> &out)
44-
{
45-
if (input_size < 4)
46-
return;
47-
auto codec = Settings::Instance().m_codec;
48-
std::array<uint8_t, 3> header = {{0, 0, 1}};
49-
auto end = input + input_size;
50-
auto header_start = input;
51-
while (header_start != end)
52-
{
53-
auto next_header = std::search(header_start + 3, end, header.begin(), header.end());
54-
if (next_header != end and next_header[-1] == 0)
55-
{
56-
next_header--;
57-
}
58-
if (codec == ALVR_CODEC_H264 and should_keep_nal_h264(header_start))
59-
out.insert(out.end(), header_start, next_header);
60-
if (codec == ALVR_CODEC_H265 and should_keep_nal_h265(header_start))
61-
out.insert(out.end(), header_start, next_header);
62-
header_start = next_header;
63-
}
64-
}
65-
66-
}
67-
6815
void alvr::EncodePipeline::SetBitrate(int64_t bitrate) {
6916
encoder_ctx->bit_rate = bitrate;
7017
encoder_ctx->rc_buffer_size = bitrate / Settings::Instance().m_refreshRate * 1.1;
@@ -112,17 +59,20 @@ alvr::EncodePipeline::~EncodePipeline()
11259
avcodec_free_context(&encoder_ctx);
11360
}
11461

115-
bool alvr::EncodePipeline::GetEncoded(std::vector<uint8_t> &out, uint64_t *pts)
62+
bool alvr::EncodePipeline::GetEncoded(FramePacket &packet)
11663
{
117-
AVPacket * enc_pkt = av_packet_alloc();
118-
int err = avcodec_receive_packet(encoder_ctx, enc_pkt);
119-
if (err == AVERROR(EAGAIN)) {
120-
return false;
121-
} else if (err) {
64+
av_packet_free(&encoder_packet);
65+
encoder_packet = av_packet_alloc();
66+
int err = avcodec_receive_packet(encoder_ctx, encoder_packet);
67+
if (err != 0) {
68+
av_packet_free(&encoder_packet);
69+
if (err == AVERROR(EAGAIN)) {
70+
return false;
71+
}
12272
throw alvr::AvException("failed to encode", err);
12373
}
124-
filter_NAL(enc_pkt->data, enc_pkt->size, out);
125-
*pts = enc_pkt->pts;
126-
av_packet_free(&enc_pkt);
74+
packet.data = encoder_packet->data;
75+
packet.size = encoder_packet->size;
76+
packet.pts = encoder_packet->pts;
12777
return true;
12878
}

alvr/server/cpp/platform/linux/EncodePipeline.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <vector>
55

66
extern "C" struct AVCodecContext;
7+
extern "C" struct AVPacket;
78

89
class Renderer;
910

@@ -14,6 +15,12 @@ class VkFrame;
1415
class VkFrameCtx;
1516
class VkContext;
1617

18+
struct FramePacket {
19+
uint8_t *data;
20+
int size;
21+
uint64_t pts;
22+
};
23+
1724
class EncodePipeline
1825
{
1926
public:
@@ -25,13 +32,14 @@ class EncodePipeline
2532
virtual ~EncodePipeline();
2633

2734
virtual void PushFrame(uint64_t targetTimestampNs, bool idr) = 0;
28-
virtual bool GetEncoded(std::vector<uint8_t> & out, uint64_t *pts);
35+
virtual bool GetEncoded(FramePacket &data);
2936
virtual Timestamp GetTimestamp() { return timestamp; }
3037

3138
virtual void SetBitrate(int64_t bitrate);
3239
static std::unique_ptr<EncodePipeline> Create(Renderer *render, VkContext &vk_ctx, VkFrame &input_frame, VkFrameCtx &vk_frame_ctx, uint32_t width, uint32_t height);
3340
protected:
3441
AVCodecContext *encoder_ctx = nullptr; //shall be initialized by child class
42+
AVPacket *encoder_packet = NULL;
3543
Timestamp timestamp = {};
3644
};
3745

alvr/server/cpp/platform/linux/EncodePipelineAMF.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -456,26 +456,27 @@ void EncodePipelineAMF::PushFrame(uint64_t targetTimestampNs, bool idr)
456456
m_amfComponents.front()->SubmitInput(surface);
457457
}
458458

459-
bool EncodePipelineAMF::GetEncoded(std::vector<uint8_t> &out, uint64_t *pts)
459+
bool EncodePipelineAMF::GetEncoded(FramePacket &packet)
460460
{
461+
m_frameBuffer = NULL;
461462
if (m_hasQueryTimeout) {
462463
m_pipeline->Run();
463464
} else {
464465
uint32_t timeout = 4 * 1000; // 1 second
465-
while (m_outBuffer.empty() && --timeout != 0) {
466+
while (m_frameBuffer == NULL && --timeout != 0) {
466467
std::this_thread::sleep_for(std::chrono::microseconds(250));
467468
m_pipeline->Run();
468469
}
469470
}
470471

471-
if (m_outBuffer.empty()) {
472+
if (m_frameBuffer == NULL) {
472473
Error("Timed out waiting for encoder data");
473474
return false;
474475
}
475476

476-
out = m_outBuffer;
477-
*pts = m_targetTimestampNs;
478-
m_outBuffer.clear();
477+
packet.data = reinterpret_cast<uint8_t *>(m_frameBuffer->GetNative());
478+
packet.size = static_cast<int>(m_frameBuffer->GetSize());
479+
packet.pts = m_targetTimestampNs;
479480

480481
uint64_t query;
481482
VK_CHECK(vkGetQueryPoolResults(m_render->m_dev, m_queryPool, 0, 1, sizeof(uint64_t), &query, sizeof(uint64_t), VK_QUERY_RESULT_64_BIT));
@@ -499,12 +500,7 @@ void EncodePipelineAMF::SetBitrate(int64_t bitrate)
499500

500501
void EncodePipelineAMF::Receive(amf::AMFDataPtr data)
501502
{
502-
amf::AMFBufferPtr buffer(data); // query for buffer interface
503-
504-
char *p = reinterpret_cast<char*>(buffer->GetNative());
505-
int length = static_cast<int>(buffer->GetSize());
506-
507-
m_outBuffer = std::vector<uint8_t>(p, p + length);
503+
m_frameBuffer = amf::AMFBufferPtr(data); // query for buffer interface
508504
}
509505

510506
void EncodePipelineAMF::ApplyFrameProperties(const amf::AMFSurfacePtr &surface, bool insertIDR)

alvr/server/cpp/platform/linux/EncodePipelineAMF.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class EncodePipelineAMF : public EncodePipeline
6868
~EncodePipelineAMF();
6969

7070
void PushFrame(uint64_t targetTimestampNs, bool idr) override;
71-
bool GetEncoded(std::vector<uint8_t> &out, uint64_t *pts) override;
71+
bool GetEncoded(FramePacket &packet) override;
7272
void SetBitrate(int64_t bitrate) override;
7373

7474
private:
@@ -96,7 +96,7 @@ class EncodePipelineAMF : public EncodePipeline
9696
int m_bitrateInMBits;
9797

9898
bool m_hasQueryTimeout = false;
99-
std::vector<uint8_t> m_outBuffer;
99+
amf::AMFBufferPtr m_frameBuffer;
100100
uint64_t m_targetTimestampNs;
101101
};
102102

0 commit comments

Comments
 (0)