forked from shotover/shotover-proxy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path valkey_get_rewrite.rs
More file actions
123 lines (107 loc) · 3.32 KB
/
valkey_get_rewrite.rs
File metadata and controls
123 lines (107 loc) · 3.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use anyhow::Result;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, MessageType, ValkeyFrame};
use shotover::message::{MessageIdSet, Messages};
use shotover::transforms::{
ChainState, Transform, TransformBuilder, TransformConfig, TransformContextConfig,
};
use shotover::transforms::{DownChainProtocol, TransformContextBuilder, UpChainProtocol};
/// Configuration for the `ValkeyGetRewrite` transform.
///
/// Deserialized with serde; `deny_unknown_fields` makes any extra key in the
/// config an error rather than being silently ignored.
#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct ValkeyGetRewriteConfig {
    /// Name returned by `TransformConfig::get_name` for this config.
    pub name: String,
    /// Value written, as a bulk string, over the response to every GET request.
    pub result: String,
}
/// Name reported by `get_name` on both `ValkeyGetRewriteBuilder` and `ValkeyGetRewrite`.
const NAME: &str = "ValkeyGetRewrite";
// The typetag name must be a string literal, so it cannot reference `NAME`;
// keep the two in sync by hand.
#[typetag::serde(name = "ValkeyGetRewrite")]
#[async_trait(?Send)]
impl TransformConfig for ValkeyGetRewriteConfig {
    /// Returns the user-supplied `name` from the config.
    fn get_name(&self) -> &str {
        self.name.as_str()
    }

    /// Produces a builder that carries its own copy of `result`.
    ///
    /// The transform context is not consulted.
    async fn get_builder(
        &self,
        _transform_context: TransformContextConfig,
    ) -> Result<Box<dyn TransformBuilder>> {
        let builder = ValkeyGetRewriteBuilder {
            result: self.result.clone(),
        };
        Ok(Box::new(builder))
    }

    /// Only Valkey messages are accepted from upstream.
    fn up_chain_protocol(&self) -> UpChainProtocol {
        let accepted = vec![MessageType::Valkey];
        UpChainProtocol::MustBeOneOf(accepted)
    }

    /// The downstream protocol is declared as `SameAsUpChain`.
    fn down_chain_protocol(&self) -> DownChainProtocol {
        DownChainProtocol::SameAsUpChain
    }

    /// This transform declares no sub-chains.
    fn get_sub_chain_configs(
        &self,
    ) -> Vec<(&shotover::config::chain::TransformChainConfig, String)> {
        Vec::new()
    }
}
/// Builder returned by `ValkeyGetRewriteConfig::get_builder`; each `build`
/// call creates a new `ValkeyGetRewrite` holding a clone of `result`.
pub struct ValkeyGetRewriteBuilder {
    /// Replacement value copied into every transform this builder creates.
    result: String,
}
impl TransformBuilder for ValkeyGetRewriteBuilder {
    /// Reports the shared transform name.
    fn get_name(&self) -> &'static str {
        NAME
    }

    /// Creates a transform instance with an empty set of pending GET request
    /// IDs and its own copy of the replacement value.
    ///
    /// The transform context is not consulted.
    fn build(&self, _transform_context: TransformContextBuilder) -> Box<dyn Transform> {
        let transform = ValkeyGetRewrite {
            result: self.result.clone(),
            get_requests: MessageIdSet::default(),
        };
        Box::new(transform)
    }
}
/// Transform that replaces the response to every GET request with a fixed
/// bulk-string value.
pub struct ValkeyGetRewrite {
    /// IDs of requests recognized as GET whose response has not been seen yet.
    /// Entries are inserted when a GET request passes through and removed only
    /// when a response carrying that request ID is returned; the set is not
    /// cleared between `transform` calls.
    get_requests: MessageIdSet,
    /// Value written over each matched GET response.
    result: String,
}
#[async_trait]
impl Transform for ValkeyGetRewrite {
    /// Reports the shared transform name.
    fn get_name(&self) -> &'static str {
        NAME
    }

    /// Records which requests are GETs, forwards the batch down the chain, and
    /// overwrites the frame of each response whose request ID was recorded.
    ///
    /// Because `get_requests` lives on `self` and is not cleared between calls,
    /// an ID recorded in one call can still be matched by a response returned
    /// from a later call. IDs are only removed when a matching response is seen.
    async fn transform<'shorter, 'longer: 'shorter>(
        &mut self,
        chain_state: &'shorter mut ChainState<'longer>,
    ) -> Result<Messages> {
        // Pass 1: remember GET request IDs so the matching responses can be found.
        for request in chain_state.requests.iter_mut() {
            let request_is_get = request.frame().map(|frame| is_get(frame)).unwrap_or(false);
            if request_is_get {
                self.get_requests.insert(request.id());
            }
        }

        let mut responses = chain_state.call_next_transform().await?;

        // Pass 2: rewrite responses that answer a previously recorded GET.
        for response in responses.iter_mut() {
            let answers_get = match response.request_id() {
                Some(request_id) => self.get_requests.remove(&request_id),
                None => false,
            };
            if !answers_get {
                continue;
            }
            if let Some(frame) = response.frame() {
                rewrite_get(frame, &self.result);
                // NOTE(review): presumably discards a cached encoding of the old
                // frame so the replacement is what gets sent — confirm against
                // `Message::invalidate_cache`.
                response.invalidate_cache();
            }
        }

        Ok(responses)
    }
}
/// Returns `true` when `frame` is a Valkey array whose first element is a bulk
/// string equal to `GET` under ASCII case-insensitive comparison.
///
/// Any other frame shape (non-array, empty array, non-bulk-string first
/// element) yields `false`.
fn is_get(frame: &Frame) -> bool {
    match frame {
        Frame::Valkey(ValkeyFrame::Array(array)) => matches!(
            array.first(),
            Some(ValkeyFrame::BulkString(command)) if command.eq_ignore_ascii_case(b"GET")
        ),
        _ => false,
    }
}
/// Overwrites `frame` in place with a Valkey bulk string containing `result`,
/// logging the old frame and the new value at info level first.
fn rewrite_get(frame: &mut Frame, result: &str) {
    tracing::info!("Replaced {frame:?} with BulkString(\"{result}\")");
    let replacement = ValkeyFrame::BulkString(result.to_owned().into());
    *frame = Frame::Valkey(replacement);
}