Skip to content

Commit b29b5af

Browse files
committed
pr: Allow TCP connections.
1 parent 1ccaaf2 commit b29b5af

File tree

4 files changed

+63
-16
lines changed

4 files changed

+63
-16
lines changed

include/vast/Conversion/Parser/Passes.td

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ def HLToParser : Pass<"vast-hl-to-parser", "core::ModuleOp"> {
1414
Option< "socket", "socket", "std::string", "",
1515
"Unix socket path to use for server."
1616
>,
17+
Option< "tcp_port", "tcp-port", "int", "-1",
18+
"TCP port to use for server."
19+
>,
20+
Option< "tcp_host", "tcp-host", "int", "0",
21+
"TCP host to use for server."
22+
>,
1723
Option< "yaml_out", "yaml-out", "std::string", "",
1824
"Path to YAML output file for models got from user."
1925
>

include/vast/server/io.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ namespace vast::server {
9595
void close() override;
9696

9797
static std::unique_ptr< sock_adapter > create_unix_socket(const std::string &path);
98+
static std::unique_ptr< sock_adapter >
99+
create_tcp_server_socket(uint32_t host, uint16_t port);
98100

99101
private:
100102
std::unique_ptr< struct impl > pimpl;

lib/vast/Conversion/Parser/ToParser.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -1163,6 +1163,12 @@ namespace vast::conv {
11631163
vast::server::sock_adapter::create_unix_socket(socket), 1,
11641164
server_handler{ models }
11651165
);
1166+
} else if (tcp_port >= 0) {
1167+
server = std::make_shared<
1168+
vast::server::server< server_handler, get_function_model_request > >(
1169+
vast::server::sock_adapter::create_tcp_server_socket(tcp_host, tcp_port), 1,
1170+
server_handler{ models }
1171+
);
11661172
}
11671173
}
11681174

lib/vast/server/io.cpp

+49-16
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <stdexcept>
44
#include <system_error>
55

6+
#include <netinet/in.h>
67
#include <sys/socket.h>
78
#include <sys/un.h>
89
#include <unistd.h>
@@ -11,13 +12,20 @@ namespace vast::server {
1112
union addr {
1213
sockaddr base;
1314
sockaddr_un unix;
15+
sockaddr_in net;
1416
};
1517

1618
struct descriptor
1719
{
1820
int fd;
1921

20-
explicit descriptor(int fd) : fd(fd) {}
22+
explicit descriptor() : fd(-1) {}
23+
24+
explicit descriptor(int fd) : fd(fd) {
25+
if (fd < 0) {
26+
throw std::system_error(errno, std::generic_category());
27+
}
28+
}
2129

2230
descriptor(const descriptor &) = delete;
2331
descriptor &operator=(const descriptor &) = delete;
@@ -53,8 +61,8 @@ namespace vast::server {
5361
sock_adapter::~sock_adapter() = default;
5462

5563
void sock_adapter::close() {
56-
pimpl->clientd = descriptor{ -1 };
57-
pimpl->serverd = descriptor{ -1 };
64+
pimpl->clientd = descriptor{};
65+
pimpl->serverd = descriptor{};
5866
}
5967

6068
size_t sock_adapter::read_some(std::span< char > dst) {
@@ -73,15 +81,29 @@ namespace vast::server {
7381
return static_cast< size_t >(res);
7482
}
7583

84+
static descriptor bind_and_accept(
85+
descriptor &serverd, addr &sockaddr_server, size_t socklen_server,
86+
sockaddr *sockaddr_client, socklen_t *socklen_client
87+
) {
88+
int rc = bind(serverd, &sockaddr_server.base, static_cast< socklen_t >(socklen_server));
89+
if (rc < 0) {
90+
throw std::system_error(errno, std::generic_category());
91+
}
92+
93+
rc = listen(serverd, 1);
94+
if (rc < 0) {
95+
throw std::system_error(errno, std::generic_category());
96+
}
97+
98+
return descriptor{ accept(serverd, sockaddr_client, socklen_client) };
99+
}
100+
76101
std::unique_ptr< sock_adapter > sock_adapter::create_unix_socket(const std::string &path) {
77102
if (path.size() > (sizeof(sockaddr_un::sun_path) - 1)) {
78103
throw std::runtime_error("Unix socket pathname is too long");
79104
}
80105

81106
descriptor serverd{ socket(AF_UNIX, SOCK_STREAM, 0) };
82-
if (serverd < 0) {
83-
throw std::system_error(errno, std::generic_category());
84-
}
85107

86108
addr sock_addr{};
87109
sock_addr.unix.sun_family = AF_UNIX;
@@ -95,18 +117,29 @@ namespace vast::server {
95117
if (unlink(path.c_str()) < 0 && errno != ENOENT) {
96118
throw std::system_error(errno, std::generic_category());
97119
}
98-
int rc =
99-
bind(serverd, &sock_addr.base, static_cast< socklen_t >(SUN_LEN(&sock_addr.unix)));
100-
if (rc < 0) {
101-
throw std::system_error(errno, std::generic_category());
102-
}
103120

104-
rc = listen(serverd, 1);
105-
if (rc < 0) {
106-
throw std::system_error(errno, std::generic_category());
107-
}
121+
auto clientd =
122+
bind_and_accept(serverd, sock_addr, sizeof(sock_addr.unix), nullptr, nullptr);
123+
124+
return std::unique_ptr< sock_adapter >(new sock_adapter{
125+
std::make_unique< impl >(std::move(serverd), std::move(clientd)) });
126+
}
127+
128+
std::unique_ptr< sock_adapter >
129+
sock_adapter::create_tcp_server_socket(uint32_t host, uint16_t port) {
130+
descriptor serverd{ socket(AF_INET, SOCK_STREAM, 0) };
131+
132+
addr sock_addr{};
133+
sock_addr.net.sin_family = AF_INET;
134+
sock_addr.net.sin_addr.s_addr = htonl(host);
135+
sock_addr.net.sin_port = htons(port);
136+
137+
addr client_addr{};
138+
socklen_t client_addr_size = sizeof(client_addr.net);
108139

109-
descriptor clientd{ accept(serverd, nullptr, nullptr) };
140+
auto clientd = bind_and_accept(
141+
serverd, sock_addr, sizeof(sock_addr.net), &client_addr.base, &client_addr_size
142+
);
110143

111144
return std::unique_ptr< sock_adapter >(new sock_adapter{
112145
std::make_unique< impl >(std::move(serverd), std::move(clientd)) });

0 commit comments

Comments
 (0)