Skip to content

Commit b45dbf4

Browse files
conn_id placeholder. client based auth
1 parent f3545b0 commit b45dbf4

File tree

10 files changed

+124
-49
lines changed

10 files changed

+124
-49
lines changed

client.cpp

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,9 @@
4040
#include <limits.h>
4141
#endif
4242

43-
#ifdef HAVE_ASSERT_H
44-
#include <assert.h>
45-
#endif
4643

44+
#include <assert.h>
45+
#include <limits.h>
4746
#include <math.h>
4847
#include <algorithm>
4948
#include <arpa/inet.h>
@@ -59,7 +58,14 @@ bool client::setup_client(benchmark_config *config, abstract_protocol *protocol,
5958
unsigned long long total_num_of_clients = config->clients*config->threads;
6059

6160
// create main connection
62-
shard_connection* conn = new shard_connection(m_connections.size(), this, m_config, m_event_base, protocol);
61+
unsigned int thread_id = 0; // TODO: set actual thread id if available
62+
unsigned int client_index = m_connections.size();
63+
unsigned int num_clients_per_thread = config->clients;
64+
unsigned int conn_id = thread_id * num_clients_per_thread + client_index;
65+
shard_connection* conn = new shard_connection(
66+
client_index, this, m_config, m_event_base, protocol,
67+
conn_id
68+
);
6369
m_connections.push_back(conn);
6470

6571
m_obj_gen = objgen->clone();
@@ -99,7 +105,7 @@ bool client::setup_client(benchmark_config *config, abstract_protocol *protocol,
99105
return true;
100106
}
101107

102-
client::client(client_group* group) :
108+
client::client(client_group* group, unsigned int conn_id) :
103109
m_event_base(NULL), m_initialized(false), m_end_set(false), m_config(NULL),
104110
m_obj_gen(NULL), m_stats(group->get_config()), m_reqs_processed(0), m_reqs_generated(0),
105111
m_set_ratio_count(0), m_get_ratio_count(0),
@@ -108,16 +114,21 @@ client::client(client_group* group) :
108114
{
109115
m_event_base = group->get_event_base();
110116

117+
// Initialize conn_id string and value with prefix
118+
m_conn_id_str = "user" + std::to_string(conn_id);
119+
m_conn_id_value = m_conn_id_str.c_str();
120+
m_conn_id_value_len = m_conn_id_str.length();
121+
111122
if (!setup_client(group->get_config(), group->get_protocol(), group->get_obj_gen())) {
112123
return;
113124
}
114125

115-
benchmark_debug_log("new client %p successfully set up.\n", this);
126+
benchmark_debug_log("new client %p successfully set up with conn_id: %s.\n", this, m_conn_id_value);
116127
m_initialized = true;
117128
}
118129

119130
client::client(struct event_base *event_base, benchmark_config *config,
120-
abstract_protocol *protocol, object_generator *obj_gen) :
131+
abstract_protocol *protocol, object_generator *obj_gen, unsigned int conn_id) :
121132
m_event_base(NULL), m_initialized(false), m_end_set(false), m_config(NULL),
122133
m_obj_gen(NULL), m_stats(config), m_reqs_processed(0), m_reqs_generated(0),
123134
m_set_ratio_count(0), m_get_ratio_count(0),
@@ -126,11 +137,16 @@ client::client(struct event_base *event_base, benchmark_config *config,
126137
{
127138
m_event_base = event_base;
128139

140+
// Initialize conn_id string and value
141+
m_conn_id_str = std::to_string(conn_id);
142+
m_conn_id_value = m_conn_id_str.c_str();
143+
m_conn_id_value_len = m_conn_id_str.length();
144+
129145
if (!setup_client(config, protocol, obj_gen)) {
130146
return;
131147
}
132148

133-
benchmark_debug_log("new client %p successfully set up.\n", this);
149+
benchmark_debug_log("new client %p successfully set up with conn_id: %s.\n", this, m_conn_id_value);
134150
m_initialized = true;
135151
}
136152

@@ -273,7 +289,11 @@ bool client::create_arbitrary_request(unsigned int command_index, struct timeval
273289

274290
const arbitrary_command& cmd = get_arbitrary_command(command_index);
275291

276-
benchmark_debug_log("%s: %s:\n", m_connections[conn_id]->get_readable_id(), cmd.command.c_str());
292+
benchmark_debug_log("%s: %s", m_connections[conn_id]->get_readable_id(), cmd.command.c_str());
293+
294+
// Build final command string for debug output
295+
std::string final_command = cmd.command;
296+
bool has_substitutions = false;
277297

278298
for (unsigned int i = 0; i < cmd.command_args.size(); i++) {
279299
const command_arg* arg = &cmd.command_args[i];
@@ -293,9 +313,32 @@ bool client::create_arbitrary_request(unsigned int command_index, struct timeval
293313
assert(value_len > 0);
294314

295315
cmd_size += m_connections[conn_id]->send_arbitrary_command(arg, value, value_len);
316+
} else if (arg->type == conn_id_type) {
317+
// Replace __conn_id__ placeholder with actual connection ID
318+
std::string substituted_arg = arg->data;
319+
size_t pos = substituted_arg.find(CONN_PLACEHOLDER);
320+
if (pos != std::string::npos) {
321+
substituted_arg.replace(pos, strlen(CONN_PLACEHOLDER), m_conn_id_value);
322+
has_substitutions = true;
323+
}
324+
325+
cmd_size += m_connections[conn_id]->send_arbitrary_command(arg, substituted_arg.c_str(), substituted_arg.length());
326+
327+
// Replace placeholder in final command string for debug output
328+
pos = final_command.find(CONN_PLACEHOLDER);
329+
if (pos != std::string::npos) {
330+
final_command.replace(pos, strlen(CONN_PLACEHOLDER), m_conn_id_value);
331+
}
296332
}
297333
}
298334

335+
// Show final command if substitutions were made
336+
if (has_substitutions) {
337+
benchmark_debug_log(" -> %s\n", final_command.c_str());
338+
} else {
339+
benchmark_debug_log("\n");
340+
}
341+
299342
m_connections[conn_id]->send_arbitrary_command_end(command_index, &timestamp, cmd_size);
300343
return true;
301344
}
@@ -581,8 +624,8 @@ bool verify_client::finished(void)
581624

582625
///////////////////////////////////////////////////////////////////////////
583626

584-
client_group::client_group(benchmark_config* config, abstract_protocol *protocol, object_generator* obj_gen) :
585-
m_base(NULL), m_config(config), m_protocol(protocol), m_obj_gen(obj_gen)
627+
client_group::client_group(benchmark_config* config, abstract_protocol *protocol, object_generator* obj_gen, unsigned int thread_id) :
628+
m_base(NULL), m_config(config), m_protocol(protocol), m_obj_gen(obj_gen), m_thread_id(thread_id)
586629
{
587630
m_base = event_base_new();
588631
assert(m_base != NULL);
@@ -608,11 +651,12 @@ int client_group::create_clients(int num)
608651
{
609652
for (int i = 0; i < num; i++) {
610653
client* c;
654+
unsigned int conn_id = m_thread_id * num + i + 1;
611655

612656
if (m_config->cluster_mode)
613-
c = new cluster_client(this);
657+
c = new cluster_client(this, conn_id);
614658
else
615-
c = new client(this);
659+
c = new client(this, conn_id);
616660

617661
assert(c != NULL);
618662

client.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ class client : public connections_manager {
6363
// test related
6464
benchmark_config* m_config;
6565
object_generator* m_obj_gen;
66+
std::string m_conn_id_str;
67+
const char* m_conn_id_value;
68+
unsigned int m_conn_id_value_len;
6669
run_stats m_stats;
6770

6871
unsigned long long m_reqs_processed; // requests processed (responses received)
@@ -78,13 +81,14 @@ class client : public connections_manager {
7881
keylist *m_keylist; // used to construct multi commands
7982

8083
public:
81-
client(client_group* group);
82-
client(struct event_base *event_base, benchmark_config *config, abstract_protocol *protocol, object_generator *obj_gen);
84+
client(client_group* group, unsigned int conn_id = 0);
85+
client(struct event_base *event_base, benchmark_config *config, abstract_protocol *protocol, object_generator *obj_gen, unsigned int conn_id = 0);
8386
virtual ~client();
8487
bool setup_client(benchmark_config *config, abstract_protocol *protocol, object_generator *obj_gen);
8588
int prepare(void);
8689
bool initialized(void);
8790
run_stats* get_stats(void) { return &m_stats; }
91+
const char* get_conn_id_value(void) { return m_conn_id_value; }
8892

8993
virtual get_key_response get_key_for_conn(unsigned int command_index, unsigned int conn_id, unsigned long long* key_index);
9094
virtual bool create_arbitrary_request(unsigned int command_index, struct timeval& timestamp, unsigned int conn_id);
@@ -203,8 +207,10 @@ class client_group {
203207
abstract_protocol* m_protocol;
204208
object_generator* m_obj_gen;
205209
std::vector<client*> m_clients;
210+
protected:
211+
unsigned int m_thread_id;
206212
public:
207-
client_group(benchmark_config *cfg, abstract_protocol *protocol, object_generator* obj_gen);
213+
client_group(benchmark_config *cfg, abstract_protocol *protocol, object_generator* obj_gen, unsigned int thread_id);
208214
~client_group();
209215

210216
int create_clients(int count);

cluster_client.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <assert.h>
12
/*
23
* Copyright (C) 2011-2017 Redis Labs Ltd.
34
*
@@ -108,7 +109,7 @@ static uint32_t calc_hslot_crc16_cluster(const char *str, size_t length)
108109

109110
///////////////////////////////////////////////////////////////////////////////////////////////////////
110111

111-
cluster_client::cluster_client(client_group* group) : client(group)
112+
cluster_client::cluster_client(client_group* group, unsigned int conn_id) : client(group, conn_id)
112113
{
113114
}
114115

@@ -159,9 +160,11 @@ void cluster_client::disconnect(void)
159160
}
160161

161162
shard_connection* cluster_client::create_shard_connection(abstract_protocol* abs_protocol) {
162-
shard_connection* sc = new shard_connection(m_connections.size(), this,
163-
m_config, m_event_base,
164-
abs_protocol);
163+
unsigned int conn_id = m_connections.size();
164+
shard_connection* sc = new shard_connection(
165+
conn_id, this, m_config, m_event_base, abs_protocol,
166+
conn_id
167+
);
165168
assert(sc != NULL);
166169

167170
m_connections.push_back(sc);

cluster_client.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ class cluster_client : public client {
4343
request *request, protocol_response *response);
4444

4545
public:
46-
cluster_client(client_group* group);
46+
cluster_client(client_group* group, unsigned int conn_id);
4747
virtual ~cluster_client();
4848

4949
virtual get_key_response get_key_for_conn(unsigned int command_index, unsigned int conn_id, unsigned long long* key_index);

config_types.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,12 +105,14 @@ struct server_addr {
105105

106106
#define KEY_PLACEHOLDER "__key__"
107107
#define DATA_PLACEHOLDER "__data__"
108+
#define CONN_PLACEHOLDER "__conn_id__"
108109

109110
enum command_arg_type {
110111
const_type = 0,
111112
key_type = 1,
112113
data_type = 2,
113-
undefined_type = 3
114+
conn_id_type = 3,
115+
undefined_type = 4
114116
};
115117

116118
struct command_arg {

memtier_benchmark.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,8 @@ struct cg_thread {
12301230
m_protocol = protocol_factory(m_config->protocol);
12311231
assert(m_protocol != NULL);
12321232

1233-
m_cg = new client_group(m_config, m_protocol, m_obj_gen);
1233+
// Pass thread_id to client_group
1234+
m_cg = new client_group(m_config, m_protocol, m_obj_gen, m_thread_id);
12341235
}
12351236

12361237
~cg_thread()

protocol.cpp

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#include <assert.h>
12
/*
23
* Copyright (C) 2011-2017 Redis Labs Ltd.
34
*
@@ -175,7 +176,7 @@ class redis_protocol : public abstract_protocol {
175176
redis_protocol() : m_response_state(rs_initial), m_bulk_len(0), m_response_len(0), m_total_bulks_count(0), m_current_mbulk(NULL), m_resp3(false), m_attribute(false) { }
176177
virtual redis_protocol* clone(void) { return new redis_protocol(); }
177178
virtual int select_db(int db);
178-
virtual int authenticate(const char *credentials);
179+
virtual int authenticate(const char *user, const char *credentials);
179180
virtual int configure_protocol(enum PROTOCOL_TYPE type);
180181
virtual int write_command_cluster_slots();
181182
virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset);
@@ -206,7 +207,7 @@ int redis_protocol::select_db(int db)
206207
return size;
207208
}
208209

209-
int redis_protocol::authenticate(const char *credentials)
210+
int redis_protocol::authenticate(const char *user, const char *credentials)
210211
{
211212
int size = 0;
212213
assert(credentials != NULL);
@@ -219,7 +220,6 @@ int redis_protocol::authenticate(const char *credentials)
219220
* contains a colon.
220221
*/
221222

222-
const char *user = NULL;
223223
const char *password;
224224

225225
if (credentials[0] == ':') {
@@ -229,12 +229,11 @@ int redis_protocol::authenticate(const char *credentials)
229229
if (!password) {
230230
password = credentials;
231231
} else {
232-
user = credentials;
233232
password++;
234233
}
235234
}
236235

237-
if (!user) {
236+
if (!user || strlen(user) == 0) {
238237
size = evbuffer_add_printf(m_write_buf,
239238
"*2\r\n"
240239
"$4\r\n"
@@ -243,17 +242,16 @@ int redis_protocol::authenticate(const char *credentials)
243242
"%s\r\n",
244243
strlen(password), password);
245244
} else {
246-
size_t user_len = password - user - 1;
245+
size_t user_len = strlen(user);
247246
size = evbuffer_add_printf(m_write_buf,
248247
"*3\r\n"
249248
"$4\r\n"
250249
"AUTH\r\n"
251250
"$%zu\r\n"
252-
"%.*s\r\n"
251+
"%s\r\n"
253252
"$%zu\r\n"
254253
"%s\r\n",
255254
user_len,
256-
(int) user_len,
257255
user,
258256
strlen(password),
259257
password);
@@ -723,8 +721,10 @@ bool redis_protocol::format_arbitrary_command(arbitrary_command &cmd) {
723721
benchmark_error_log("error: data placeholder can't combined with other data\n");
724722
return false;
725723
}
726-
727724
current_arg->type = data_type;
725+
} else if (current_arg->data.find(CONN_PLACEHOLDER) != std::string::npos) {
726+
// Allow conn_id placeholder to be combined with other text
727+
current_arg->type = conn_id_type;
728728
}
729729

730730
// we expect that first arg is the COMMAND name
@@ -761,7 +761,7 @@ class memcache_text_protocol : public abstract_protocol {
761761
memcache_text_protocol() : m_response_state(rs_initial), m_value_len(0), m_response_len(0) { }
762762
virtual memcache_text_protocol* clone(void) { return new memcache_text_protocol(); }
763763
virtual int select_db(int db);
764-
virtual int authenticate(const char *credentials);
764+
virtual int authenticate(const char *user, const char *credentials);
765765
virtual int configure_protocol(enum PROTOCOL_TYPE type);
766766
virtual int write_command_cluster_slots();
767767
virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset);
@@ -782,7 +782,7 @@ int memcache_text_protocol::select_db(int db)
782782
assert(0);
783783
}
784784

785-
int memcache_text_protocol::authenticate(const char *credentials)
785+
int memcache_text_protocol::authenticate(const char *user, const char *credentials)
786786
{
787787
assert(0);
788788
}
@@ -983,7 +983,7 @@ class memcache_binary_protocol : public abstract_protocol {
983983
memcache_binary_protocol() : m_response_state(rs_initial), m_response_len(0) { }
984984
virtual memcache_binary_protocol* clone(void) { return new memcache_binary_protocol(); }
985985
virtual int select_db(int db);
986-
virtual int authenticate(const char *credentials);
986+
virtual int authenticate(const char *user, const char *credentials);
987987
virtual int configure_protocol(enum PROTOCOL_TYPE type);
988988
virtual int write_command_cluster_slots();
989989
virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset);
@@ -1003,14 +1003,13 @@ int memcache_binary_protocol::select_db(int db)
10031003
assert(0);
10041004
}
10051005

1006-
int memcache_binary_protocol::authenticate(const char *credentials)
1006+
int memcache_binary_protocol::authenticate(const char *user, const char *credentials)
10071007
{
10081008
protocol_binary_request_no_extras req;
10091009
char nullbyte = '\0';
10101010
const char mechanism[] = "PLAIN";
10111011
int mechanism_len = sizeof(mechanism) - 1;
10121012
const char *colon;
1013-
const char *user;
10141013
int user_len;
10151014
const char *passwd;
10161015
int passwd_len;
@@ -1019,8 +1018,8 @@ int memcache_binary_protocol::authenticate(const char *credentials)
10191018
colon = strchr(credentials, ':');
10201019
assert(colon != NULL);
10211020

1022-
user = credentials;
1023-
user_len = colon - user;
1021+
// Use the user parameter instead of extracting from credentials
1022+
user_len = strlen(user);
10241023
passwd = colon + 1;
10251024
passwd_len = strlen(passwd);
10261025

protocol.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ class abstract_protocol {
183183
void set_keep_value(bool flag);
184184

185185
virtual int select_db(int db) = 0;
186-
virtual int authenticate(const char *credentials) = 0;
186+
virtual int authenticate(const char *user, const char *credentials) = 0;
187187
virtual int configure_protocol(enum PROTOCOL_TYPE type) = 0;
188188
virtual int write_command_cluster_slots() = 0;
189189
virtual int write_command_set(const char *key, int key_len, const char *value, int value_len, int expiry, unsigned int offset) = 0;

0 commit comments

Comments
 (0)