Skip to content

Commit 0427f0c

Browse files
Vignesh2208copybara-github
authored andcommitted
[Filter Fusion] Create method and Fused Filter TypeName Implementation
PiperOrigin-RevId: 749157291
1 parent a63bcae commit 0427f0c

File tree

5 files changed

+211
-41
lines changed

5 files changed

+211
-41
lines changed

src/core/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9178,12 +9178,17 @@ grpc_cc_library(
91789178
],
91799179
external_deps = [
91809180
"absl/status",
9181+
"absl/strings",
9182+
"absl/log:check",
91819183
"absl/log",
9184+
"absl/memory",
91829185
],
91839186
deps = [
91849187
"call_filters",
91859188
"call_final_info",
9189+
"channel_args",
91869190
"metadata",
9191+
"status_helper",
91879192
"type_list",
91889193
"//:grpc_base",
91899194
"//:grpc_public_hdrs",

src/core/call/filter_fusion.h

Lines changed: 83 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,25 @@
1616
#define GRPC_SRC_CORE_CALL_FILTER_FUSION_H
1717
#include <grpc/impl/grpc_types.h>
1818

19+
#include <cstddef>
20+
#include <memory>
21+
#include <string>
22+
#include <tuple>
1923
#include <type_traits>
2024
#include <utility>
2125

26+
#include "absl/log/check.h"
2227
#include "absl/log/log.h"
28+
#include "absl/memory/memory.h"
2329
#include "absl/status/status.h"
30+
#include "absl/strings/str_join.h"
2431
#include "src/core/call/call_filters.h"
2532
#include "src/core/call/metadata.h"
33+
#include "src/core/lib/channel/channel_args.h"
2634
#include "src/core/lib/channel/promise_based_filter.h"
2735
#include "src/core/lib/transport/call_final_info.h"
36+
#include "src/core/lib/transport/transport.h"
37+
#include "src/core/util/status_helper.h"
2838
#include "src/core/util/type_list.h"
2939

3040
struct grpc_transport_op;
@@ -893,6 +903,58 @@ GRPC_FUSE_METHOD(OnClientToServerMessage, MessageHandle, true);
893903
GRPC_FUSE_METHOD(OnServerToClientMessage, MessageHandle, false);
894904
GRPC_FUSE_METHOD(OnFinalize, const grpc_call_final_info*, true);
895905

906+
template <typename... Filters>
907+
struct FilterWrapper;
908+
909+
template <typename Filter>
910+
struct FilterWrapper<Filter> {
911+
FilterWrapper(const ChannelArgs& args, ChannelFilter::Args filter_args)
912+
: filter_(Filter::Create(args, filter_args)) {}
913+
914+
absl::Status status() const { return filter_.status(); }
915+
bool StartTransportOp(grpc_transport_op* op) {
916+
CHECK(filter_.ok());
917+
return (*filter_)->StartTransportOp(op);
918+
}
919+
920+
bool GetChannelInfo(const grpc_channel_info* info) {
921+
CHECK(filter_.ok());
922+
return (*filter_)->GetChannelInfo(info);
923+
}
924+
925+
private:
926+
absl::StatusOr<std::unique_ptr<Filter>> filter_;
927+
};
928+
929+
template <typename Filter0, typename... Filters>
930+
struct FilterWrapper<Filter0, Filters...> : public FilterWrapper<Filters...> {
931+
FilterWrapper(const ChannelArgs& args, ChannelFilter::Args filter_args)
932+
: filter0_(Filter0::Create(args, filter_args)),
933+
FilterWrapper<Filters...>(args, filter_args) {}
934+
935+
absl::Status status() const {
936+
if (filter0_.ok()) {
937+
return FilterWrapper<Filters...>::status();
938+
}
939+
return filter0_.status();
940+
}
941+
942+
bool StartTransportOp(grpc_transport_op* op) {
943+
CHECK(filter0_.ok());
944+
return (*filter0_)->StartTransportOp(op) ||
945+
FilterWrapper<Filters...>::StartTransportOp(op);
946+
}
947+
948+
bool GetChannelInfo(const grpc_channel_info* info) {
949+
CHECK(filter0_.ok());
950+
return (*filter0_)->GetChannelInfo(info) ||
951+
FilterWrapper<Filters...>::GetChannelInfo(info);
952+
}
953+
954+
private:
955+
absl::StatusOr<std::unique_ptr<Filter0>> filter0_;
956+
};
957+
896958
#undef GRPC_FUSE_METHOD
897959

898960
template <typename... Filters>
@@ -902,19 +964,29 @@ class FusedFilter : public ImplementChannelFilter<FusedFilter<Filters...>>,
902964
static const grpc_channel_filter kFilter;
903965

904966
static absl::string_view TypeName() {
905-
// TODO(vigneshbabu): - Concatenate the names of the constituent filters.
906-
return "fused_filter";
967+
static const std::string kName = absl::StrCat(
968+
"Fused_Filter_",
969+
absl::StrJoin(std::make_tuple(Filters::TypeName()...), "_"));
970+
return kName;
907971
}
908972

909973
static absl::StatusOr<std::unique_ptr<FusedFilter<Filters...>>> Create(
910-
const ChannelArgs& args, ChannelFilter::Args) {
911-
// TODO(vigneshbabu): - Implement this.
912-
LOG(FATAL) << "Not implemented";
913-
return nullptr;
974+
const ChannelArgs& args, ChannelFilter::Args filter_args) {
975+
auto filters_wrapper =
976+
std::make_unique<FilterWrapper<Filters...>>(args, filter_args);
977+
GRPC_RETURN_IF_ERROR(filters_wrapper->status());
978+
return absl::WrapUnique<FusedFilter<Filters...>>(
979+
new FusedFilter<Filters...>(std::move(filters_wrapper)));
914980
}
915981

916982
static constexpr bool IsFused = true;
917983

984+
static constexpr bool FusedFilterHasAsyncErrorInterceptor() {
985+
return (
986+
promise_filter_detail::CallHasAsyncErrorInterceptor<Filters>::value ||
987+
...);
988+
}
989+
918990
class Call : public FuseOnClientInitialMetadata<FusedFilter, Filters...>,
919991
public FuseOnServerInitialMetadata<FusedFilter, Filters...>,
920992
public FuseOnClientToServerMessage<FusedFilter, Filters...>,
@@ -947,27 +1019,18 @@ class FusedFilter : public ImplementChannelFilter<FusedFilter<Filters...>>,
9471019
};
9481020

9491021
bool StartTransportOp(grpc_transport_op* op) override {
950-
return StartTransportOpInternal(op, Typelist<Filters...>());
1022+
return filters_->StartTransportOp(op);
9511023
}
9521024

9531025
bool GetChannelInfo(const grpc_channel_info* info) override {
954-
return GetChannelInfoInternal(info, Typelist<Filters...>());
1026+
return filters_->GetChannelInfo(info);
9551027
}
9561028

9571029
private:
958-
template <typename... FilterTypes>
959-
bool StartTransportOpInternal(grpc_transport_op* op,
960-
Typelist<FilterTypes...>) {
961-
return (std::get<FilterTypes>(filters_).StartTransportOp(op) || ...);
962-
}
963-
964-
template <typename... FilterTypes>
965-
bool GetChannelInfoInternal(const grpc_channel_info* info,
966-
Typelist<FilterTypes...>) {
967-
return (std::get<FilterTypes>(filters_).GetChannelInfo(info) || ...);
968-
}
1030+
explicit FusedFilter(std::unique_ptr<FilterWrapper<Filters...>> filters)
1031+
: filters_(std::move(filters)) {};
9691032

970-
std::tuple<Filters...> filters_;
1033+
std::unique_ptr<FilterWrapper<Filters...>> filters_;
9711034
};
9721035

9731036
} // namespace filters_detail

src/core/lib/channel/promise_based_filter.h

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,37 @@ inline constexpr bool HasAnyAsyncErrorInterceptor(Interceptors...) {
167167
return (HasAsyncErrorInterceptor<Interceptors>::value || ...);
168168
}
169169

170+
// value is true if Derived has a member called IsFused.
171+
template <typename Derived>
172+
struct IsFusedFilter {
173+
template <typename V>
174+
static std::true_type test(decltype(&V::IsFused)); // SFINAE context
175+
template <typename V>
176+
static std::false_type test(...);
177+
178+
using type = decltype(test<Derived>(nullptr));
179+
static constexpr bool value = std::is_same_v<type, std::true_type>;
180+
};
181+
182+
template <typename Derived, typename Ignored = void>
183+
struct CallHasAsyncErrorInterceptor;
184+
170185
// Composite for a given channel type to determine if any of its interceptors
171186
// fall into this category: later code should use this.
172187
template <typename Derived>
173-
inline constexpr bool CallHasAsyncErrorInterceptor() {
174-
return HasAnyAsyncErrorInterceptor(&Derived::Call::OnClientToServerMessage,
175-
&Derived::Call::OnServerInitialMetadata,
176-
&Derived::Call::OnServerToClientMessage);
177-
}
188+
struct CallHasAsyncErrorInterceptor<
189+
Derived, std::enable_if_t<!IsFusedFilter<Derived>::value>> {
190+
static constexpr bool value =
191+
HasAnyAsyncErrorInterceptor(&Derived::Call::OnClientToServerMessage,
192+
&Derived::Call::OnServerInitialMetadata,
193+
&Derived::Call::OnServerToClientMessage);
194+
};
195+
196+
template <typename Derived>
197+
struct CallHasAsyncErrorInterceptor<
198+
Derived, std::enable_if_t<IsFusedFilter<Derived>::value>> {
199+
static constexpr bool value = Derived::FusedFilterHasAsyncErrorInterceptor();
200+
};
178201

179202
// Given a boolean X export a type:
180203
// either T if X is true
@@ -246,7 +269,7 @@ struct FilterCallData {
246269
GPR_NO_UNIQUE_ADDRESS CallWrapper<Derived> call;
247270
GPR_NO_UNIQUE_ADDRESS
248271
typename TypeIfNeeded<Latch<ServerMetadataHandle>,
249-
CallHasAsyncErrorInterceptor<Derived>()>::Type
272+
CallHasAsyncErrorInterceptor<Derived>::value>::Type
250273
error_latch;
251274
GPR_NO_UNIQUE_ADDRESS
252275
typename TypeIfNeeded<
@@ -317,7 +340,7 @@ auto MapResult(void (Derived::Call::*fn)(ServerMetadata&, Derived*), Promise x,
317340
// For fused filters whose OnServerTrailingMetadata takes pointer to the
318341
// channel.
319342
template <typename P, typename Call, typename Derived,
320-
typename = std::enable_if_t<Derived::IsFused>>
343+
typename = std::enable_if_t<IsFusedFilter<Derived>::value>>
321344
auto MapResult(void (Call::*fn)(ServerMetadata&, Derived*), P x,
322345
FilterCallData<Derived>* call_data) {
323346
DCHECK(fn == &Derived::Call::OnServerTrailingMetadata);
@@ -338,7 +361,7 @@ auto MapResult(void (Call::*fn)(ServerMetadata&, Derived*), P x,
338361
// For fused filters whose OnServerTrailingMetadata does not take pointer to the
339362
// channel.
340363
template <typename P, typename Call, typename Derived,
341-
typename = std::enable_if_t<Derived::IsFused>>
364+
typename = std::enable_if_t<IsFusedFilter<Derived>::value>>
342365
auto MapResult(void (Call::*fn)(ServerMetadata&), P x,
343366
FilterCallData<Derived>* call_data) {
344367
DCHECK(fn == &Derived::Call::OnServerTrailingMetadata);
@@ -1183,7 +1206,8 @@ class ImplementChannelFilter : public ChannelFilter,
11831206
return promise_filter_detail::MapResult(
11841207
&Derived::Call::OnServerTrailingMetadata,
11851208
promise_filter_detail::RaceAsyncCompletion<
1186-
promise_filter_detail::CallHasAsyncErrorInterceptor<Derived>()>::
1209+
promise_filter_detail::CallHasAsyncErrorInterceptor<
1210+
Derived>::value>::
11871211
Run(promise_filter_detail::RunCall(
11881212
&Derived::Call::OnClientInitialMetadata,
11891213
std::move(call_args), std::move(next_promise_factory),

0 commit comments

Comments
 (0)