Skip to content

Commit 866eb9e

Browse files
martindukecopybara-github
authored andcommitted
Factor common Bidi Stream elements out of MoqtSession::ControlStream into MoqtBidiStream.
Create "UnknownBidiStream" to handle streams where the type is unknown. Allow unavailable streams and unsent messages to be queued. PiperOrigin-RevId: 862928045
1 parent 628c623 commit 866eb9e

14 files changed

+751
-221
lines changed

build/source_list.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,6 +1573,7 @@ load_balancer_srcs = [
15731573
"quic/load_balancer/load_balancer_server_id_test.cc",
15741574
]
15751575
moqt_hdrs = [
1576+
"quic/moqt/moqt_bidi_stream.h",
15761577
"quic/moqt/moqt_bitrate_adjuster.h",
15771578
"quic/moqt/moqt_error.h",
15781579
"quic/moqt/moqt_fetch_task.h",
@@ -1635,6 +1636,7 @@ moqt_srcs = [
16351636
moqt_test_hdrs = [
16361637
]
16371638
moqt_test_srcs = [
1639+
"quic/moqt/moqt_bidi_stream_test.cc",
16381640
"quic/moqt/moqt_bitrate_adjuster_test.cc",
16391641
"quic/moqt/moqt_framer_test.cc",
16401642
"quic/moqt/moqt_integration_test.cc",

build/source_list.gni

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,6 +1577,7 @@ load_balancer_srcs = [
15771577
"src/quiche/quic/load_balancer/load_balancer_server_id_test.cc",
15781578
]
15791579
moqt_hdrs = [
1580+
"src/quiche/quic/moqt/moqt_bidi_stream.h",
15801581
"src/quiche/quic/moqt/moqt_bitrate_adjuster.h",
15811582
"src/quiche/quic/moqt/moqt_error.h",
15821583
"src/quiche/quic/moqt/moqt_fetch_task.h",
@@ -1640,6 +1641,7 @@ moqt_test_hdrs = [
16401641

16411642
]
16421643
moqt_test_srcs = [
1644+
"src/quiche/quic/moqt/moqt_bidi_stream_test.cc",
16431645
"src/quiche/quic/moqt/moqt_bitrate_adjuster_test.cc",
16441646
"src/quiche/quic/moqt/moqt_framer_test.cc",
16451647
"src/quiche/quic/moqt/moqt_integration_test.cc",

build/source_list.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,6 +1576,7 @@
15761576
"quiche/quic/load_balancer/load_balancer_server_id_test.cc"
15771577
],
15781578
"moqt_hdrs": [
1579+
"quiche/quic/moqt/moqt_bidi_stream.h",
15791580
"quiche/quic/moqt/moqt_bitrate_adjuster.h",
15801581
"quiche/quic/moqt/moqt_error.h",
15811582
"quiche/quic/moqt/moqt_fetch_task.h",
@@ -1639,6 +1640,7 @@
16391640

16401641
],
16411642
"moqt_test_srcs": [
1643+
"quiche/quic/moqt/moqt_bidi_stream_test.cc",
16421644
"quiche/quic/moqt/moqt_bitrate_adjuster_test.cc",
16431645
"quiche/quic/moqt/moqt_framer_test.cc",
16441646
"quiche/quic/moqt/moqt_integration_test.cc",
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
// Copyright (c) 2026 The Chromium Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style license that can be
3+
// found in the LICENSE file.
4+
5+
#ifndef QUICHE_QUIC_MOQT_MOQT_BIDI_STREAM_H
6+
#define QUICHE_QUIC_MOQT_MOQT_BIDI_STREAM_H
7+
8+
#include <array>
9+
#include <cstddef>
10+
#include <cstdint>
11+
#include <memory>
12+
#include <queue>
13+
#include <utility>
14+
15+
#include "absl/base/nullability.h"
16+
#include "absl/status/status.h"
17+
#include "absl/strings/string_view.h"
18+
#include "absl/types/span.h"
19+
#include "quiche/quic/moqt/moqt_error.h"
20+
#include "quiche/quic/moqt/moqt_framer.h"
21+
#include "quiche/quic/moqt/moqt_key_value_pair.h"
22+
#include "quiche/quic/moqt/moqt_messages.h"
23+
#include "quiche/quic/moqt/moqt_parser.h"
24+
#include "quiche/common/platform/api/quiche_bug_tracker.h"
25+
#include "quiche/common/quiche_buffer_allocator.h"
26+
#include "quiche/common/quiche_callbacks.h"
27+
#include "quiche/common/quiche_mem_slice.h"
28+
#include "quiche/common/quiche_stream.h"
29+
#include "quiche/web_transport/web_transport.h"
30+
31+
namespace moqt {
32+
33+
enum class MoqtBidiStreamType : uint8_t {
34+
kUnknown,
35+
kControl,
36+
kSubscribeNamespace, // TODO(martinduke): Support this case.
37+
};
38+
39+
using SessionErrorCallback =
40+
quiche::SingleUseCallback<void(MoqtError, absl::string_view)>;
41+
// The provider of this callback owns nothing in MoqtBidiStreamBase. This merely
42+
// deletes the record.
43+
using BidiStreamDeletedCallback = quiche::SingleUseCallback<void()>;
44+
45+
// A generic parser visitor that assumes all messages are invalid. Serves a base
46+
// class for visitors that accept a subset of messages and maintains state based
47+
// on those messages.
48+
class MoqtBidiStreamBase : public MoqtControlParserVisitor,
49+
public webtransport::StreamVisitor {
50+
public:
51+
MoqtBidiStreamBase(MoqtFramer* absl_nonnull framer,
52+
BidiStreamDeletedCallback stream_deleted_callback,
53+
SessionErrorCallback session_error_callback)
54+
: framer_(framer),
55+
stream_deleted_callback_(std::move(stream_deleted_callback)),
56+
session_error_callback_(std::move(session_error_callback)) {}
57+
~MoqtBidiStreamBase() override { std::move(stream_deleted_callback_)(); }
58+
// The caller is responsible for calling stream->SetVisitor(). Derived
59+
// classes will wrap this with a call to stream->SetPriority().
60+
virtual void set_stream(webtransport::Stream* absl_nonnull stream) {
61+
stream_ = stream;
62+
parser_ = std::make_unique<MoqtControlParser>(framer_->using_webtrans(),
63+
stream_, *this);
64+
}
65+
66+
// MoqtControlParserVisitor implementation. All control messages are protocol
67+
// violations by default.
68+
virtual void OnClientSetupMessage(const MoqtClientSetup& message) override {
69+
OnParsingError(wrong_message_error_, wrong_message_reason_);
70+
}
71+
virtual void OnServerSetupMessage(const MoqtServerSetup& message) override {
72+
OnParsingError(wrong_message_error_, wrong_message_reason_);
73+
}
74+
virtual void OnRequestOkMessage(const MoqtRequestOk& message) override {
75+
OnParsingError(wrong_message_error_, wrong_message_reason_);
76+
}
77+
virtual void OnRequestErrorMessage(const MoqtRequestError& message) override {
78+
OnParsingError(wrong_message_error_, wrong_message_reason_);
79+
}
80+
virtual void OnSubscribeMessage(const MoqtSubscribe& message) override {
81+
OnParsingError(wrong_message_error_, wrong_message_reason_);
82+
}
83+
virtual void OnSubscribeOkMessage(const MoqtSubscribeOk& message) override {
84+
OnParsingError(wrong_message_error_, wrong_message_reason_);
85+
}
86+
virtual void OnUnsubscribeMessage(const MoqtUnsubscribe& message) override {
87+
OnParsingError(wrong_message_error_, wrong_message_reason_);
88+
}
89+
virtual void OnPublishDoneMessage(const MoqtPublishDone& message) override {
90+
OnParsingError(wrong_message_error_, wrong_message_reason_);
91+
}
92+
virtual void OnSubscribeUpdateMessage(
93+
const MoqtSubscribeUpdate& message) override {
94+
OnParsingError(wrong_message_error_, wrong_message_reason_);
95+
}
96+
virtual void OnPublishNamespaceMessage(
97+
const MoqtPublishNamespace& message) override {
98+
OnParsingError(wrong_message_error_, wrong_message_reason_);
99+
}
100+
virtual void OnPublishNamespaceDoneMessage(
101+
const MoqtPublishNamespaceDone& message) override {
102+
OnParsingError(wrong_message_error_, wrong_message_reason_);
103+
}
104+
virtual void OnPublishNamespaceCancelMessage(
105+
const MoqtPublishNamespaceCancel& message) override {
106+
OnParsingError(wrong_message_error_, wrong_message_reason_);
107+
}
108+
virtual void OnTrackStatusMessage(const MoqtTrackStatus& message) override {
109+
OnParsingError(wrong_message_error_, wrong_message_reason_);
110+
}
111+
virtual void OnGoAwayMessage(const MoqtGoAway& message) override {
112+
OnParsingError(wrong_message_error_, wrong_message_reason_);
113+
}
114+
virtual void OnSubscribeNamespaceMessage(
115+
const MoqtSubscribeNamespace& message) override {
116+
OnParsingError(wrong_message_error_, wrong_message_reason_);
117+
}
118+
virtual void OnUnsubscribeNamespaceMessage(
119+
const MoqtUnsubscribeNamespace& message) override {
120+
OnParsingError(wrong_message_error_, wrong_message_reason_);
121+
}
122+
virtual void OnMaxRequestIdMessage(const MoqtMaxRequestId& message) override {
123+
OnParsingError(wrong_message_error_, wrong_message_reason_);
124+
}
125+
virtual void OnFetchMessage(const MoqtFetch& message) override {
126+
OnParsingError(wrong_message_error_, wrong_message_reason_);
127+
}
128+
virtual void OnFetchCancelMessage(const MoqtFetchCancel& message) override {
129+
OnParsingError(wrong_message_error_, wrong_message_reason_);
130+
}
131+
virtual void OnFetchOkMessage(const MoqtFetchOk& message) override {
132+
OnParsingError(wrong_message_error_, wrong_message_reason_);
133+
}
134+
virtual void OnRequestsBlockedMessage(
135+
const MoqtRequestsBlocked& message) override {
136+
OnParsingError(wrong_message_error_, wrong_message_reason_);
137+
}
138+
virtual void OnPublishMessage(const MoqtPublish& message) override {
139+
OnParsingError(wrong_message_error_, wrong_message_reason_);
140+
}
141+
virtual void OnPublishOkMessage(const MoqtPublishOk& message) override {
142+
OnParsingError(wrong_message_error_, wrong_message_reason_);
143+
}
144+
virtual void OnObjectAckMessage(const MoqtObjectAck& message) override {
145+
OnParsingError(wrong_message_error_, wrong_message_reason_);
146+
}
147+
virtual void OnParsingError(MoqtError code,
148+
absl::string_view reason) override {
149+
std::move(session_error_callback_)(code, reason);
150+
}
151+
152+
// webtransport::StreamVisitor implementation.
153+
void OnResetStreamReceived(webtransport::StreamErrorCode error) override {}
154+
void OnStopSendingReceived(webtransport::StreamErrorCode error) override {}
155+
void OnWriteSideInDataRecvdState() override {}
156+
void OnCanRead() override {
157+
if (parser_ == nullptr) {
158+
QUICHE_BUG(quiche_bug_moqt_parser_is_null) << "Parser is null";
159+
return;
160+
}
161+
parser_->ReadAndDispatchMessages();
162+
}
163+
void OnCanWrite() override {
164+
if (pending_messages_.empty() && fin_queued_) {
165+
if (!stream_->SendFin()) {
166+
std::move(session_error_callback_)(MoqtError::kInternalError,
167+
"Failed to send FIN");
168+
}
169+
return;
170+
}
171+
while (!pending_messages_.empty() && stream_->CanWrite()) {
172+
SendMessage(std::move(pending_messages_.front()),
173+
fin_queued_ && pending_messages_.size() == 1);
174+
pending_messages_.pop();
175+
}
176+
}
177+
178+
void SendOrBufferMessage(quiche::QuicheBuffer message, bool fin = false) {
179+
if (fin_queued_) {
180+
return;
181+
}
182+
if (stream_ == nullptr || !stream_->CanWrite()) {
183+
AddToQueue(std::move(message));
184+
return;
185+
}
186+
SendMessage(std::move(message), fin);
187+
}
188+
void SendRequestOk(uint64_t request_id,
189+
const VersionSpecificParameters& parameters,
190+
bool fin = false) {
191+
SendOrBufferMessage(
192+
framer_->SerializeRequestOk(MoqtRequestOk{request_id, parameters}),
193+
fin);
194+
}
195+
void SendRequestError(uint64_t request_id, RequestErrorCode error_code,
196+
absl::string_view reason_phrase, bool fin = false) {
197+
MoqtRequestError request_error;
198+
request_error.request_id = request_id;
199+
request_error.error_code = error_code;
200+
request_error.reason_phrase = reason_phrase;
201+
SendOrBufferMessage(framer_->SerializeRequestError(request_error), fin);
202+
}
203+
void Fin() {
204+
fin_queued_ = true;
205+
if (pending_messages_.empty()) {
206+
if (!stream_->SendFin()) {
207+
std::move(session_error_callback_)(MoqtError::kInternalError,
208+
"Failed to send FIN");
209+
}
210+
return;
211+
}
212+
}
213+
void Reset(webtransport::StreamErrorCode error) {
214+
if (stream_ != nullptr) {
215+
stream_->ResetWithUserCode(error);
216+
}
217+
}
218+
219+
protected:
220+
const size_t kMaxPendingMessages = 100;
221+
void AddToQueue(quiche::QuicheBuffer message) {
222+
if (pending_messages_.size() == kMaxPendingMessages) {
223+
std::move(session_error_callback_)(
224+
MoqtError::kInternalError,
225+
"Not enough flow credit on the control stream");
226+
return;
227+
}
228+
pending_messages_.push(std::move(message));
229+
}
230+
MoqtFramer* absl_nonnull framer_;
231+
MoqtControlParser* parser() { return parser_.get(); }
232+
void OnBidiStreamDeleted() {
233+
if (stream_deleted_callback_ != nullptr) {
234+
std::move(stream_deleted_callback_)();
235+
}
236+
}
237+
webtransport::Stream* stream() { return stream_; }
238+
239+
private:
240+
void SendMessage(quiche::QuicheBuffer message, bool fin) {
241+
quiche::StreamWriteOptions options;
242+
options.set_send_fin(fin);
243+
// TODO: while we buffer unconditionally, we should still at some point tear
244+
// down the connection if we've buffered too many control messages;
245+
// otherwise, there is potential for memory exhaustion attacks.
246+
options.set_buffer_unconditionally(true);
247+
std::array write_vector = {quiche::QuicheMemSlice(std::move(message))};
248+
absl::Status success =
249+
stream_->Writev(absl::MakeSpan(write_vector), options);
250+
if (!success.ok()) {
251+
std::move(session_error_callback_)(MoqtError::kInternalError,
252+
"Failed to write a control message");
253+
}
254+
}
255+
256+
webtransport::Stream* stream_;
257+
std::unique_ptr<MoqtControlParser> parser_;
258+
std::queue<quiche::QuicheBuffer> pending_messages_;
259+
bool fin_queued_ = false;
260+
BidiStreamDeletedCallback stream_deleted_callback_;
261+
SessionErrorCallback session_error_callback_;
262+
const MoqtError wrong_message_error_ = MoqtError::kProtocolViolation;
263+
const absl::string_view wrong_message_reason_ =
264+
"Message not allowed for this stream type";
265+
};
266+
267+
} // namespace moqt
268+
269+
#endif // QUICHE_QUIC_MOQT_MOQT_BIDI_STREAM_H

0 commit comments

Comments
 (0)