diff --git a/sockssrv.c b/sockssrv.c index dbd1f52..0e067ae 100644 --- a/sockssrv.c +++ b/sockssrv.c @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -380,10 +381,11 @@ static int usage(void) { dprintf(2, "MicroSocks SOCKS5 Server\n" "------------------------\n" - "usage: microsocks -1 -q -i listenip -p port -u user -P password -b bindaddr\n" + "usage: microsocks -1 -q -i listenip -p port -u user -P password -b bindaddr -t timeout\n" "all arguments are optional.\n" "by default listenip is 0.0.0.0 and port 1080.\n\n" "option -q disables logging.\n" + "option -t specifies an idle exit timeout in seconds. default is to wait forever\n" "option -b specifies which ip outgoing connections are bound to\n" "option -1 activates auth_once mode: once a specific ip address\n" "authed successfully with user/pass, it is added to a whitelist\n" @@ -405,7 +407,8 @@ int main(int argc, char** argv) { int ch; const char *listenip = "0.0.0.0"; unsigned port = 1080; - while((ch = getopt(argc, argv, ":1qb:i:p:u:P:")) != -1) { + unsigned idle_timeout = 0; + while((ch = getopt(argc, argv, ":1qb:i:p:t:u:P:")) != -1) { switch(ch) { case '1': auth_ips = sblist_new(sizeof(union sockaddr_union), 8); @@ -430,6 +433,9 @@ int main(int argc, char** argv) { case 'p': port = atoi(optarg); break; + case 't': + idle_timeout = atoi(optarg); + break; case ':': dprintf(2, "error: option -%c requires an operand\n", optopt); /* fall through */ @@ -454,8 +460,34 @@ int main(int argc, char** argv) { } server = &s; + if (idle_timeout && fcntl(s.fd, F_SETFL, fcntl(s.fd, F_GETFL, 0) | O_NONBLOCK)) { + perror("fcntl O_NONBLOCK"); + return 1; + } + while(1) { - collect(threads); + while(1) { + collect(threads); + if (!idle_timeout) break; + struct pollfd fds[1] = { + [0] = {.fd = s.fd, .events = POLLIN}, + }; + switch(poll(fds, 1, idle_timeout*1000)) { + case 0: + if (sblist_getsize(threads) == 0) { + dprintf(2, "idle timeout exit\n"); + return 0; + } + continue; + case -1: + if(errno != EINTR && errno != EAGAIN) { + perror("poll"); + return 1; + } + continue; + } + break; + } struct client c; struct thread *curr = malloc(sizeof (struct thread)); if(!curr) goto oom; @@ -466,6 +498,12 @@ int main(int argc, char** argv) { usleep(FAILURE_TIMEOUT); continue; } + if (idle_timeout && fcntl(c.fd, F_SETFL, fcntl(c.fd, F_GETFL, 0) & ~O_NONBLOCK)) { + perror("fcntl ~O_NONBLOCK"); + close(c.fd); + free(curr); + continue; + } curr->client = c; if(!sblist_add(threads, &curr)) { close(curr->client.fd);