#include "def.h"
#include "pqp.h"
/* libc */
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
/* posix */
#include <arpa/inet.h>
#include <netinet/in.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netdb.h>
/* libs */
#include <wireguard.h>
#include <libpq-fe.h>
/* Render a peer endpoint into d as "host:port", or as "[host]:port" for IPv6
 * so the port cannot be mistaken for the last address group.
 * Unset (family 0), non-IP, and unresolvable endpoints become fixed
 * placeholder strings.
 * d is caller-supplied and unbounded here; the local buffers cap the output
 * at 273 characters plus the terminator.
 * Returns the number of characters written, excluding the terminator. */
size_t dumpEndpoint(char* d, const wg_endpoint* const e) {
	socklen_t len;
	switch (e->addr.sa_family) {
	case AF_INET:  len = sizeof e->addr4; break;
	case AF_INET6: len = sizeof e->addr6; break;
	case 0: strcpy(d, "<endpoint unset>"); return 16;
	default: strcpy(d, "<bad endpoint>"); return 14;
	}
	char bip[256], bsrv[16];
	/* the output buffers are not guaranteed to be filled when getnameinfo
	 * fails, so only print them after a successful call */
	if (getnameinfo(&e->addr, len,
	                bip, sizeof bip,
	                bsrv, sizeof bsrv,
	                NI_NUMERICHOST | NI_NUMERICSERV) != 0) {
		strcpy(d, "<unprintable endpoint>");
		return 22;
	}
	int n;
	if (e->addr.sa_family == AF_INET6)
		n = sprintf(d, "[%s]:%s", bip, bsrv);
	else
		n = sprintf(d, "%s:%s", bip, bsrv);
	/* never let a (theoretical) negative sprintf result wrap into size_t */
	return n < 0 ? 0 : (size_t)n;
}
/* Render an allowed-IP entry into d as "address/prefix".
 * Family 0 and non-IP families become fixed placeholder strings.
 * d is caller-supplied and unbounded here; the result is at most
 * INET6_ADDRSTRLEN - 1 + 4 characters ("/128") plus the terminator.
 * Returns the number of characters written, excluding the terminator. */
size_t dumpAllowedIP(char* d, const wg_allowedip* aip) {
	const void* raw;
	switch (aip->family) {
	case AF_INET:  raw = &aip->ip4; break;
	case AF_INET6: raw = &aip->ip6; break;
	case 0: strcpy(d, "<no IP>"); return 7;
	default: strcpy(d, "<non-IP address>"); return 16;
	}
	char bip[INET6_ADDRSTRLEN];
	/* inet_ntop formats the raw address directly; on failure bip holds
	 * nothing printable, so report that instead of printing garbage */
	if (!inet_ntop(aip->family, raw, bip, sizeof bip)) {
		strcpy(d, "<bad IP>");
		return 8;
	}
	int n = sprintf(d, "%s/%u", bip, (unsigned)aip->cidr);
	return n < 0 ? 0 : (size_t)n;
}
/* Report whether two allowed-IP entries describe the same prefix: same
 * address family, same prefix length, and (for IPv4/IPv6) the same address
 * bytes. Entries of any other family compare equal once family and prefix
 * length match. List links (next_allowedip) are ignored. */
bool compare_allowedip(const wg_allowedip* const a, const wg_allowedip* const b) {
	if (a->family != b->family || a->cidr != b->cidr)
		return false;
	switch (a->family) {
	case AF_INET:
		return a->ip4.s_addr == b->ip4.s_addr;
	case AF_INET6:
		return memcmp(a->ip6.s6_addr, b->ip6.s6_addr, sizeof(a->ip6.s6_addr)) == 0;
	default:
		return true;
	}
}
wg_allowedip
inet_to_allowedip(const char* data) {
pqp_sockstore ss;
if(!pqp_inet_read(data, &ss.sock))
_fatal("bad IP value in database");
wg_allowedip wgip = {
.family = ss.sock.sa_family,
.next_allowedip = null,
};
switch(ss.sock.sa_family) {
case AF_INET:
wgip.cidr = 32;
wgip.ip4 = ss.sock_in.sin_addr;
break;
case AF_INET6:
wgip.cidr = 128;
wgip.ip6 = ss.sock_in6.sin6_addr;
break;
default: _fatal("unhandled address family");
}
return wgip;
}
/* Release a peer together with every allowed-IP node on its list.
 * The loop follows the wireguard library's own peer-freeing code
 * (ext/wglib/wireguard.c:1486). peer is dereferenced unconditionally, so it
 * must not be NULL. */
void wgd_free_peer(wg_peer* peer) {
	wg_allowedip* node = peer->first_allowedip;
	while (node != NULL) {
		/* read the link before the node it lives in is released */
		wg_allowedip* after = node->next_allowedip;
		free(node);
		node = after;
	}
	free(peer);
}
/* linked list manipulation routines
 *
 * list.h is included twice, each time configured by the five _ll_* macros
 * defined just before it:
 *   _ll_rec  - record name (peer / allowedip)
 *   _ll_box  - struct that owns the list
 *   _ll_obj  - element type of the list
 *   _ll_iter - wireguard.h iteration macro for that list
 *   _ll_ns   - prefix for the generated names
 * NOTE(review): the roles above are read off the macro arguments only; the
 * macros are redefined for the second inclusion, so list.h presumably
 * #undefs them. syncauth() calls wgd_drop_peer(), which presumably comes
 * from the first instantiation (the hand-written copy below is under
 * #if 0). Confirm against list.h. */
/* the device's peer list (namespace wgd) */
#define _ll_rec peer
#define _ll_box wg_device
#define _ll_obj wg_peer
#define _ll_iter wg_for_each_peer
#define _ll_ns wgd
#include "list.h"
/* each peer's allowed-IP list (namespace wgd_peer) */
#define _ll_rec allowedip
#define _ll_box wg_peer
#define _ll_obj wg_allowedip
#define _ll_iter wg_for_each_allowedip
#define _ll_ns wgd_peer
#include "list.h"
/* Hand-written unlink-and-free routines for the peer and allowed-IP lists.
 * They are compiled out. NOTE(review): presumably superseded by the list.h
 * instantiations above; confirm before re-enabling, since both may define
 * the same function names. */
#if 0
/* Unlink peer from dev's peer list, keeping first_peer/last_peer
 * consistent, then release it (and its allowed IPs) via wgd_free_peer().
 * Calls _fatal() if peer is not first and no predecessor is found. */
void wgd_drop_peer(wg_device* dev, wg_peer* peer) {
	if(dev -> first_peer == peer) {
		if(dev -> last_peer == peer) {
			dev -> first_peer = dev -> last_peer = null;
		} else {
			dev -> first_peer = peer -> next_peer;
		}
	} else {
		wg_peer* p;
		if(dev -> last_peer == peer) {
			/* find the predecessor; it becomes the new tail */
			wg_for_each_peer(dev, p) {
				if(p->next_peer == peer) {
					dev -> last_peer = p;
					p->next_peer = null;
					goto found1;
				}
			}
			_fatal("BUG in last peer deletion routine");
			found1 :;
		} else /* in the middle */ {
			/* find the predecessor and link it past peer */
			wg_for_each_peer(dev, p) {
				if(p->next_peer == peer) {
					p->next_peer = peer -> next_peer;
					goto found2;
				}
			}
			_fatal("BUG in peer deletion routine");
			found2 :;
		}
	}
	wgd_free_peer(peer);
}
/* Unlink ip from peer's allowed-IP list, keeping first_allowedip and
 * last_allowedip consistent, then free the node.
 * Calls _fatal() if ip is not first and no predecessor is found. */
void wgd_peer_drop_ip(wg_peer* peer, wg_allowedip* ip) {
	if(peer -> first_allowedip == ip) {
		if(peer -> last_allowedip == ip) {
			peer -> first_allowedip = peer -> last_allowedip = null;
		} else {
			/* the new head is ip's successor (this previously read
			 * peer->next_allowedip, which wg_peer does not have) */
			peer -> first_allowedip = ip -> next_allowedip;
		}
	} else {
		wg_allowedip* a;
		if(peer -> last_allowedip == ip) {
			wg_for_each_allowedip(peer, a) {
				if(a->next_allowedip == ip) {
					peer -> last_allowedip = a;
					a->next_allowedip = null;
					goto found1;
				}
			}
			_fatal("BUG in last aIP deletion routine");
			found1 :;
		} else /* in the middle */ {
			wg_for_each_allowedip(peer, a) {
				if(a->next_allowedip == ip) {
					a->next_allowedip = ip -> next_allowedip;
					goto found2;
				}
			}
			_fatal("BUG in aIP deletion routine");
			found2 :;
		}
	}
	free(ip);
}
#endif
void syncauth(PGconn* db, const char* wgdev) {
wg_device* wg;
if (wg_get_device(&wg, wgdev))
_fatal("no wireguard device by that name");
bool dirty = false;
size_t peerc = 0;
{ wg_peer* p; wg_for_each_peer(wg, p) ++ peerc; };
bool valid_peers [peerc];
_zero(valid_peers);
PGresult* rows = PQexecPrepared(db, "get_hosts",
0, null, null, null, 1);
if(!(rows && PQresultStatus(rows) == PGRES_TUPLES_OK))
_fatal(PQerrorMessage(db));
size_t rowc = PQntuples(rows);
for(size_t i = 0; i < rowc; ++i) {
const char* key_b64 = PQgetvalue(rows, i, 0);
const char* aryraw = PQgetvalue(rows, i, 1);
pqp_array* ips = pqp_array_read(aryraw);
_dbgf("DB has peer %s", key_b64);
if(ips->ty != pq_array_inet)
_fatal("incorrect array type returned from DB");
wg_key key;
if (wg_key_from_base64(key, key_b64) < 0) {
_warnf("invalid key in database: %s", key_b64);
continue;
}
wg_peer* found = null;
{ size_t j=0; wg_peer* p; wg_for_each_peer(wg, p) {
if(memcmp(p->public_key, key, sizeof key) == 0) {
_dbgf("validating peer %s", key_b64);
valid_peers[j] = true;
found = p;
break;
}
++j;}}
if (found) {
/* compare and update IPs if necessary */
bool goodIPs [ips -> sz]; _zero(goodIPs);
/* extant IPs that are not marked good by the
* end of the following loop must be deleted
* from memory */
size_t goodIPc = 0;
for (size_t j = 0; j < ips -> sz; ++j) {
char inetstr[256];
wg_allowedip aip = inet_to_allowedip(ips -> elts[j].data);
dumpAllowedIP(inetstr, &aip);
_dbgf("IP PG%zu :: %s", j, inetstr);
size_t l = 0;
wg_allowedip* wgip;
bool foundIP = false;
wg_for_each_allowedip(found, wgip) {
if (compare_allowedip(&aip, wgip)) {
++goodIPc; goodIPs[l] = true;
foundIP = true;
}
++l;}
if(!foundIP) {
/* this IP hasn't been loaded into the
* kernel yet; upload it now */
_infof("inserting IP PG%zu %s", j, inetstr);
dirty = true;
}
}
if(goodIPc < ips -> sz) {
size_t l = 0;
wg_allowedip* wgip;
wg_for_each_allowedip(found, wgip) {
char inetstr[256];
dumpAllowedIP(inetstr, wgip);
_dbgf("IP WG%zu :: %s", l, inetstr);
if(!goodIPs[l]) {
/* this IP is stale, delete it */
_infof("deleting IP WG%zu %s", l, inetstr);
dirty = true;
}
++l;}
}
} else {
_infof("inserting key %s", key_b64);
dirty = true;
/* install new peer */
for (size_t j = 0; j < ips -> sz; ++j) {
char inetstr[256];
wg_allowedip aip = inet_to_allowedip(ips -> elts[j].data);
dumpAllowedIP(inetstr, &aip);
_dbgf("new IP %zu :: %s", j, inetstr);
}
}
free(ips);
}
{ size_t i=0; wg_peer* p; wg_for_each_peer(wg, p) {
if(valid_peers[i] == false) {
char b64 [128];
wg_key_to_base64(b64, p->public_key);
_infof("dropping peer %s", b64);
wgd_drop_peer(wg, p);
dirty = true;
}
++i;}}
_dbg("final peer list:");
{ size_t j=0; wg_peer* p; wg_for_each_peer(wg, p) {
char b64 [128];
wg_key_to_base64(b64, p->public_key);
_dbgf("P%zu :: %s", j, b64);
++j;}}
if(dirty) wg_set_device(wg);
PQclear(rows);
}
/* Entry point: `<binary> <mode> <device>`.
 * Connects to the "domain" database and prepares the get_hosts query. It
 * then regains root if it was only dropped to the effective uid, and
 * dispatches on mode. Only "sync" does work so far; "wait" and "fork" are
 * empty placeholders. */
int main(int argc, char** argv) {
	setvbuf(stderr, null, _IONBF, 0);
	if (argc < 3) {
		_fatal("missing device name");
	}
	const char* arg_mode = argv[1];
	const char* arg_devname = argv[2];
	/* mostly for the sake of debugging, allow the
	 * binary to be run from sudo without losing
	 * postgres peer credentials. Only the effective uid is switched to the
	 * invoking user; the saved uid stays 0 so root can be regained below. */
	if (geteuid() == 0) {
		const char* suid = getenv("SUDO_UID");
		const char* susr = getenv("SUDO_USER");
		if (suid) {
			char* end;
			unsigned long uid = strtoul(suid, &end, 10);
			/* atoi() silently mapped junk to 0. Reject malformed input,
			 * values that don't fit uid_t, and the (uid_t)-1 "no change"
			 * sentinel. */
			if (end == suid || *end != '\0' || uid >= (unsigned long)(uid_t)-1)
				_warnf("ignoring malformed SUDO_UID %s", suid);
			else if (seteuid((uid_t)uid) != 0)
				_warnf("could not switch to SUDO_UID %s", suid);
		}
		if (susr) setenv("USER", susr, 1);
	}
	PGconn* db = PQconnectdb("dbname=domain");
	if (PQstatus(db) != CONNECTION_OK)
		_fatal(PQerrorMessage(db));
	PGresult* q_get_hosts = PQprepare(db, "get_hosts",
		"select h.ref, array_remove(array_agg(wgv4::inet)"
		"|| array_agg(wgv6::inet), null)"
		"from ns, hostref h "
		"where ns.host = h.host and kind = 'pubkey' "
		" group by h.host, h.ref;", 0, null);
	/*"select ns.wgv4::inet, ns.wgv6::inet, h.ref from ns "
	  "right join hostref h "
	  "on h.host = ns.host "
	  "where h.kind = 'pubkey';"*/
	if (!(q_get_hosts && PQresultStatus(q_get_hosts) == PGRES_COMMAND_OK))
		_fatal(PQerrorMessage(db));
	PQclear(q_get_hosts);
	/* we're going to interact with WG now;
	 * get our superpowers back if we lost them.
	 * getresuid() writes through all three pointers, so each must be valid.
	 * The former NULL arguments left svuid potentially uninitialised. */
	{
		uid_t ruid, euid, svuid;
		if (getresuid(&ruid, &euid, &svuid) == 0 && svuid == 0
		    && setuid(0) != 0)
			_warnf("could not regain root privileges for %s", arg_devname);
	}
	if (strcmp(arg_mode, "sync") == 0) {
		syncauth(db, arg_devname);
	} else if (strcmp(arg_mode, "wait") == 0) {
		/* foreground daemon */
	} else if (strcmp(arg_mode, "fork") == 0) {
		/* background daemon */
	} else {
		_fatal("valid modes are sync, wait, and fork");
	}
	/* other possibilities: a mode that generates an eventfd
	 * and provides it on fd4 to a subordinate process, or
	 * sends it with SCM_RIGHTS */
	PQfinish(db);
	return 0;
}