#include "def.h"
#include "pqp.h"
/* libc */
#include <errno.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
/* posix */
#include <arpa/inet.h>
#include <netinet/in.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netdb.h>
#include <poll.h>
#if __linux__
# include <sys/signalfd.h>
# include <signal.h>
#endif
/* libs */
#include <wireguard.h>
#include "wglist.h"
/* wireguard uses messy linked lists but doesn't
* provide any routines for manipulating them;
* wglist.h fills in the gap */
#include <libpq-fe.h>
/* Render a wireguard endpoint into d as "host:port" (IPv4) or
 * "[host]:port" (IPv6, bracketed so the port is unambiguous), or a
 * placeholder when the endpoint is unset, has an unknown family, or
 * cannot be converted. Returns the number of characters written,
 * excluding the terminating NUL.
 * d is not bounds-checked: the longest possible result is 273
 * characters (255-char host, brackets, colon, 15-char port), so
 * callers must supply at least 274 bytes. */
size_t dumpEndpoint(char* d, const wg_endpoint* const e) {
	const struct sockaddr* addr;
	socklen_t len;
	switch(e->addr.sa_family) {
		case AF_INET:  addr = (const void*)&(e->addr4); len = sizeof e->addr4; break;
		case AF_INET6: addr = (const void*)&(e->addr6); len = sizeof e->addr6; break;
		case 0: strcpy(d, "<endpoint unset>"); return 16;
		default: strcpy(d, "<bad endpoint>"); return 14;
	}
	char bip[256], bsrv[16];
	/* on failure getnameinfo leaves both buffers unspecified, so
	 * they must never be printed in that case */
	if (getnameinfo(addr, len,
	                bip, sizeof bip,
	                bsrv, sizeof bsrv,
	                NI_NUMERICHOST | NI_NUMERICSERV) != 0) {
		strcpy(d, "<unprintable endpoint>");
		return 22;
	}
	if (e->addr.sa_family == AF_INET6)
		return sprintf(d, "[%s]:%s", bip, bsrv);
	return sprintf(d, "%s:%s", bip, bsrv);
}
/* Render an allowed-IP entry into d as "address/prefix", or a
 * placeholder for an unset family, a non-IP family, or an address
 * that cannot be converted. Returns the number of characters
 * written, excluding the terminating NUL.
 * d is not bounds-checked; the longest result is an IPv6 text
 * address plus "/255", well within the 256-byte buffers used by
 * the callers in this file. */
size_t dumpAllowedIP(char* d, const wg_allowedip* aip) {
	const void* raw;
	switch(aip->family) {
		case AF_INET:  raw = &aip->ip4; break;
		case AF_INET6: raw = &aip->ip6; break;
		case 0: strcpy(d, "<no IP>"); return 7;
		default: strcpy(d, "<non-IP address>"); return 16;
	}
	/* inet_ntop formats a bare address directly; no need to wrap it
	 * in a sockaddr and ask getnameinfo for a dummy service */
	char bip[INET6_ADDRSTRLEN];
	if (inet_ntop(aip->family, raw, bip, sizeof bip) == NULL) {
		strcpy(d, "<unprintable IP>");
		return 16;
	}
	return sprintf(d, "%s/%u", bip, (unsigned) aip->cidr);
}
/* Report whether two allowed-IP entries denote the same prefix:
 * the address family and prefix length must match and, for IPv4
 * and IPv6, so must the address bytes. Entries of any other family
 * compare equal as soon as family and cidr agree. The list link
 * (next_allowedip) is not considered. */
bool compare_allowedip
( const wg_allowedip* const a,
  const wg_allowedip* const b
) {
	const bool same_shape = a->family == b->family
	                     && a->cidr == b->cidr;
	if (!same_shape)
		return false;
	if (a->family == AF_INET)
		return a->ip4.s_addr == b->ip4.s_addr;
	if (a->family == AF_INET6)
		return memcmp(a->ip6.s6_addr, b->ip6.s6_addr,
		              sizeof(a->ip6.s6_addr)) == 0;
	return true;
}
/* Decode one inet value (an element handed out by pqp_array_read)
 * into a single-host allowed-IP entry: /32 for IPv4, /128 for IPv6,
 * with no successor link. Aborts via _fatal when pqp_inet_read
 * rejects the value or the family is neither IPv4 nor IPv6. */
wg_allowedip
inet_to_allowedip(const char* data) {
	pqp_sockstore store;
	if (!pqp_inet_read(data, &store.sock))
		_fatal("bad IP value in database");
	const sa_family_t fam = store.sock.sa_family;
	wg_allowedip out = {
		.family = fam,
		.next_allowedip = null,
	};
	if (fam == AF_INET) {
		out.ip4 = store.sock_in.sin_addr;
		out.cidr = 32;
	} else if (fam == AF_INET6) {
		out.ip6 = store.sock_in6.sin6_addr;
		out.cidr = 128;
	} else {
		_fatal("unhandled address family");
	}
	return out;
}
/* Free `peer` together with every allowed-IP node chained from it.
 * The peer is not unlinked from any device list beforehand; that is
 * the caller's job. Mirrors the per-peer cleanup in the bundled
 * wireguard library (ext/wglib/wireguard.c). */
void wgd_free_peer(wg_peer* peer) {
	wg_allowedip* node = peer->first_allowedip;
	while (node != NULL) {
		/* grab the successor before the node's memory goes away */
		wg_allowedip* const following = node->next_allowedip;
		free(node);
		node = following;
	}
	free(peer);
}
/* linked list manipulation routines
 *
 * NOTE(review): these are compiled out. syncauth uses
 * wgd_peer_drop_allowedip (presumably from wglist.h) and flags dead
 * peers with WGPEER_REMOVE_ME instead of calling wgd_drop_peer.
 * Kept for reference; check for clashes with wglist.h before
 * re-enabling. */
#if 0
/* Unlink `peer` from `dev`'s peer list and free it (with its
 * allowed IPs) via wgd_free_peer. Aborts via _fatal if a non-head
 * peer cannot be found in the list. */
void wgd_drop_peer(wg_device* dev, wg_peer* peer) {
	if(dev -> first_peer == peer) {
		if(dev -> last_peer == peer) {
			dev -> first_peer = dev -> last_peer = null;
		} else {
			dev -> first_peer = peer -> next_peer;
		}
	} else {
		wg_peer* p;
		if(dev -> last_peer == peer) {
			wg_for_each_peer(dev, p) {
				if(p->next_peer == peer) {
					dev -> last_peer = p;
					p->next_peer = null;
					goto found1;
				}
			}
			_fatal("BUG in last peer deletion routine");
			found1 :;
		} else /* in the middle */ {
			wg_for_each_peer(dev, p) {
				if(p->next_peer == peer) {
					p->next_peer = peer -> next_peer;
					goto found2;
				}
			}
			_fatal("BUG in peer deletion routine");
			found2 :;
		}
	}
	wgd_free_peer(peer);
}
/* Unlink `ip` from `peer`'s allowed-IP list and free it. Aborts via
 * _fatal if a non-head entry cannot be found in the list. */
void wgd_peer_drop_ip(wg_peer* peer, wg_allowedip* ip) {
	if(peer -> first_allowedip == ip) {
		if(peer -> last_allowedip == ip) {
			peer -> first_allowedip = peer -> last_allowedip = null;
		} else {
			/* the removed head's successor becomes the new head */
			peer -> first_allowedip = ip -> next_allowedip;
		}
	} else {
		wg_allowedip* a;
		if(peer -> last_allowedip == ip) {
			wg_for_each_allowedip(peer, a) {
				if(a->next_allowedip == ip) {
					peer -> last_allowedip = a;
					a->next_allowedip = null;
					goto found1;
				}
			}
			_fatal("BUG in last aIP deletion routine");
			found1 :;
		} else /* in the middle */ {
			wg_for_each_allowedip(peer, a) {
				if(a->next_allowedip == ip) {
					a->next_allowedip = ip -> next_allowedip;
					goto found2;
				}
			}
			_fatal("BUG in aIP deletion routine");
			found2 :;
		}
	}
	free(ip);
}
#endif
void syncauth(PGconn* db, const char* wgdev) {
wg_device* wg;
if (wg_get_device(&wg, wgdev))
_fatal("no wireguard device by that name");
bool dirty = false;
size_t peerc = 0;
{ wg_peer* p; wg_for_each_peer(wg, p) ++ peerc; };
bool valid_peers [peerc];
_zero(valid_peers);
PGresult* rows = PQexecPrepared(db, "get_hosts",
0, null, null, null, 1);
if(!(rows && PQresultStatus(rows) == PGRES_TUPLES_OK))
_fatal(PQerrorMessage(db));
size_t rowc = PQntuples(rows);
for(size_t i = 0; i < rowc; ++i) {
const char* key_b64 = PQgetvalue(rows, i, 0);
const char* aryraw = PQgetvalue(rows, i, 1);
pqp_array* ips = pqp_array_read(aryraw);
_dbgf("DB has peer %s", key_b64);
if(ips->ty != pq_array_inet)
_fatal("incorrect array type returned from DB");
wg_key key;
if (wg_key_from_base64(key, key_b64) < 0) {
_warnf("invalid key in database: %s", key_b64);
continue;
}
wg_peer* found = null;
{ size_t j=0; wg_peer* p; wg_for_each_peer(wg, p) {
if(memcmp(p->public_key, key, sizeof key) == 0) {
_dbgf("validating peer %s", key_b64);
valid_peers[j] = true;
found = p;
break;
}
++j;}}
if (found) {
/* compare and update IPs if necessary */
size_t wgIPc = 0;
{ wg_allowedip* a; wg_for_each_allowedip(found, a) ++wgIPc; };
bool goodIPs [wgIPc]; _zero(goodIPs);
/* extant IPs that are not marked good by the
* end of the following loop must be deleted
* from memory */
size_t goodIPc = 0;
for (size_t j = 0; j < ips -> sz; ++j) {
char inetstr[256];
wg_allowedip aip = inet_to_allowedip(ips -> elts[j].data);
dumpAllowedIP(inetstr, &aip);
_dbgf("IP PG%zu :: %s", j, inetstr);
size_t l = 0;
wg_allowedip* wgip;
bool foundIP = false;
wg_for_each_allowedip(found, wgip) {
if (compare_allowedip(&aip, wgip)) {
++goodIPc; goodIPs[l] = true;
foundIP = true;
}
++l;}
if(foundIP == false) {
/* this IP hasn't been loaded into the
* kernel yet; upload it now */
_infof("inserting IP PG%zu %s", j, inetstr);
found -> flags |= WGPEER_REPLACE_ALLOWEDIPS;
wg_allowedip* nip = wgd_peer_new_allowedip(found);
memcpy(nip, &aip, sizeof aip);
dirty = true;
}
}
if(goodIPc < wgIPc) {
size_t l = 0;
wg_allowedip* wgip;
wg_for_each_allowedip(found, wgip) {
char inetstr[256];
dumpAllowedIP(inetstr, wgip);
_dbgf("IP WG%zu :: %s", l, inetstr);
if(l<goodIPc && !goodIPs[l]) {
/* this IP is stale, delete it */
_infof("deleting IP WG%zu %s", l, inetstr);
wgd_peer_drop_allowedip(found, wgip);
found -> flags |= WGPEER_REPLACE_ALLOWEDIPS;
dirty = true;
}
++l;}
}
} else {
_infof("inserting key %s", key_b64);
dirty = true;
/* install new peer */
wg_peer* np = wgd_new_peer(wg);
np -> flags = WGPEER_HAS_PUBLIC_KEY;
memcpy(np -> public_key, key, sizeof key);
for (size_t j = 0; j < ips -> sz; ++j) {
char inetstr[256];
wg_allowedip aip = inet_to_allowedip(ips -> elts[j].data);
dumpAllowedIP(inetstr, &aip);
_dbgf("new IP %zu :: %s", j, inetstr);
wg_allowedip* nip = wgd_peer_new_allowedip(np);
memcpy(nip, &aip, sizeof aip);
}
}
free(ips);
}
{ size_t i=0; wg_peer* p; wg_for_each_peer(wg, p) {
if(i<peerc && valid_peers[i] == false) {
char b64 [128];
wg_key_to_base64(b64, p->public_key);
_infof("dropping peer %s", b64);
//wgd_drop_peer(wg, p);
p -> flags |= WGPEER_REMOVE_ME;
dirty = true;
}
++i;}}
_dbg("final peer list:");
{ size_t j=0; wg_peer* p; wg_for_each_peer(wg, p) {
char b64 [128];
wg_key_to_base64(b64, p->public_key);
_dbgf("P%zu :: %s%s", j, b64,
p->flags & WGPEER_REMOVE_ME ? " [DELETE]" :
p->flags & WGPEER_REPLACE_ALLOWEDIPS ? " [CHGIP]" : "");
++j;}}
dirty = true;
if(dirty) {
int e = wg_set_device(wg);
if(e != 0)
_fatalf("could not set wg device (error %i)", -e);
}
PQclear(rows);
}
/* Drain every notification libpq has queued, freeing each one.
 * Returns true if any arrived on a channel that calls for a resync. */
static bool daemon_drain_notifies(PGconn* db) {
	bool want = false;
	PGnotify* n;
	while ((n = PQnotifies(db)) != null) {
		if(strcmp(n->relname, "sync_vpn") == 0
		|| strcmp(n->relname, "sync_priv") == 0)
			want = true;
		PQfreemem(n);
	}
	return want;
}

/* Subscribe to the sync_vpn and sync_priv DB channels and resync the
 * wireguard device `wgdev` whenever one fires. On Linux, SIGHUP also
 * triggers a resync, and SIGTERM/SIGINT make the function return.
 * Losing the DB connection aborts via _fatal. */
void daemonmain(PGconn* db, const char* wgdev) {
	PGresult* subscribe = PQexec(db,
		"listen sync_vpn;"
		"listen sync_priv;");
	if (PQresultStatus(subscribe) != PGRES_COMMAND_OK)
		_warn("could not subscribe to DB notification channels");
	PQclear(subscribe);
	int pqfd = PQsocket(db);
#if __linux__
	sigset_t sigs;
	sigemptyset(&sigs);
	sigaddset(&sigs, SIGHUP);
	sigaddset(&sigs, SIGTERM);
	sigaddset(&sigs, SIGINT);
	sigprocmask(SIG_BLOCK, &sigs, null);
	int sigfd = signalfd(-1, &sigs, SFD_CLOEXEC);
	/* the signals are already blocked, so without a signalfd
	 * SIGTERM/SIGINT could never be acted upon */
	if (sigfd < 0)
		_fatal("could not create signalfd");
#endif
	struct pollfd polls[] = {
		{ .fd = pqfd, .events = POLLIN, .revents = 0 },
#if __linux__
		{ .fd = sigfd, .events = POLLIN, .revents = 0 },
#endif
	};
	for (;;) {
		int p = poll(polls, _sz(polls), -1);
		if (p < 0) {
			if (errno == EINTR) continue;
			_fatal("poll failed");
		}
		if (p == 0) continue;
		bool didSync = false;
		/* revents is a bit mask: test bits instead of switching on
		 * exact values, or combinations like POLLIN|POLLHUP would be
		 * ignored and poll would spin */
		if (polls[0].revents & (POLLHUP | POLLERR | POLLNVAL))
			_fatal("lost DB connection; terminating");
		if (polls[0].revents & POLLIN) {
			if (!PQconsumeInput(db))
				_fatal(PQerrorMessage(db));
			/* syncauth runs a query, during which libpq may queue
			 * further notifications without the socket becoming
			 * readable again; keep going until none remain */
			while (daemon_drain_notifies(db)) {
				syncauth(db, wgdev);
				didSync = true;
			}
		}
#if __linux__
		if (polls[1].revents & POLLIN) {
			struct signalfd_siginfo si;
			if (read(sigfd, &si, sizeof si) != (ssize_t) sizeof si) {
				_warn("short read from signalfd");
			} else if (si.ssi_signo == SIGHUP) {
				if (!didSync) syncauth(db, wgdev);
				didSync = true;
			} else if (si.ssi_signo == SIGTERM || si.ssi_signo == SIGINT) {
				goto poll_end;
			}
		}
#endif
	}
poll_end :;
	_info("shutting down");
#if __linux__
	close(sigfd);
#endif
}
/* Entry point: <mode> <device>.
 *   sync      reconcile the device with the database once
 *   wait      listen for DB notifications/signals and resync on each
 *   syncwait  sync once, then wait
 *   fork      like wait, but detach with daemon(3) first
 *   syncfork  sync once, then detach and wait
 * The libpq connection string is read from $wgsync_conn. */
int main(int argc, char** argv) {
	setvbuf(stderr, null, _IONBF, 0);
	if (argc < 3)
		_fatal("usage: <sync|wait|syncwait|fork|syncfork> <device>");
	const char* arg_mode = argv[1];
	const char* arg_devname = argv[2];

	/* decode the mode before touching privileges or the database,
	 * so a typo fails fast */
	static const struct {
		const char* name;
		bool do_sync, do_wait, do_fork;
	} modes[] = {
		{ "sync",     true,  false, false },
		{ "wait",     false, true,  false },
		{ "syncwait", true,  true,  false },
		{ "fork",     false, true,  true  },
		{ "syncfork", true,  true,  true  },
	};
	const size_t nmodes = sizeof modes / sizeof modes[0];
	size_t m = 0;
	while (m < nmodes && strcmp(arg_mode, modes[m].name) != 0)
		++m;
	if (m == nmodes)
		_fatal("valid modes are sync, wait, syncwait, fork, and syncfork");

	/* mostly for the sake of debugging, allow the binary to be run
	 * from sudo without losing postgres peer credentials: switch the
	 * effective uid to the invoking user and impersonate them in
	 * $USER */
	if(geteuid() == 0) {
		const char* suid = getenv("SUDO_UID");
		const char* susr = getenv("SUDO_USER");
		if(suid) {
			char* end;
			long uid = strtol(suid, &end, 10);
			if (end == suid || *end != '\0' || uid < 0)
				_warn("ignoring malformed SUDO_UID");
			else if (seteuid((uid_t) uid) != 0)
				_warn("could not switch effective uid to SUDO_UID");
		}
		if(susr) setenv("USER", susr, 1);
	}
	char* connstr = getenv("wgsync_conn");
	if(connstr == null) _fatal("no connection string supplied");
	PGconn* db = PQconnectdb(connstr);
	if(PQstatus(db) != CONNECTION_OK)
		_fatal(PQerrorMessage(db));
	PGresult* q_get_hosts = PQprepare(db, "get_hosts",
		"select h.ref, array_remove(array_agg(wgv4::inet)"
		"|| array_agg(wgv6::inet), null)"
		"from ns, hostref h "
		"where ns.host = h.host and kind = 'pubkey' "
		" group by h.host, h.ref;", 0, null);
	if(!(q_get_hosts && PQresultStatus(q_get_hosts) == PGRES_COMMAND_OK))
		_fatal(PQerrorMessage(db));
	PQclear(q_get_hosts);
	/* we're going to interact with WG now; get our superpowers back
	 * if we lost them. getresuid needs three valid pointers: passing
	 * null makes it fail with EFAULT, leaving svuid uninitialized */
	{
		uid_t ruid, euid, svuid;
		if (getresuid(&ruid, &euid, &svuid) != 0)
			_warn("could not query saved uid");
		else if (svuid == 0 && setuid(0) != 0)
			_warn("could not regain root privileges");
	}
	if (modes[m].do_sync)
		syncauth(db, arg_devname);
	if (modes[m].do_wait) {
		/* maybe background daemon */
		if (modes[m].do_fork && daemon(1,1) == -1)
			_fatal("cannot daemonize");
		daemonmain(db, arg_devname);
	}
	/* other possibilities: a mode that generates an eventfd
	 * and provides it on fd4 to a subordinate process, or
	 * sends it with SCM_RIGHTS */
	PQfinish(db);
	return 0;
}