/* util: wgsync.c
 *
 * File wgsync/src/wgsync.c from the latest check-in */


#include "def.h"
#include "pqp.h"

/* libc */
#include <stdlib.h>
#include <stdio.h>
#include <string.h>

/* posix */
#include <netinet/in.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netdb.h>
#include <poll.h>

#if __linux__
#	include <sys/signalfd.h>
#	include <signal.h>
#endif

/* libs */
#include <wireguard.h>
#include "wglist.h"
	/* wireguard uses messy linked lists but doesn't
	 * provide any routines for manipulating them;
	 * wglist.h fills in the gap */

#include <libpq-fe.h>


/* Format a WireGuard endpoint into `d` as "host:port", using numeric
 * host and service forms.
 *
 * `d` is not bounds-checked: callers must supply a buffer large enough
 * for the result (the host part alone may take up to 255 bytes here).
 *
 * Unset (family 0) and non-IP endpoints, as well as addresses that
 * getnameinfo cannot render, produce a fixed "<...>" placeholder.
 *
 * Returns the number of characters written, excluding the NUL. */
size_t dumpEndpoint(char* d, const wg_endpoint* const e) {
	const struct sockaddr* addr;
	socklen_t len;
	switch(e->addr.sa_family) {
		case AF_INET: addr = (const void*)&(e->addr4); len = sizeof e->addr4; break;
		case AF_INET6: addr = (const void*)&(e->addr6); len = sizeof e->addr6; break;
		case 0: strcpy(d, "<endpoint unset>"); return 16;
		default: strcpy(d, "<bad endpoint>"); return 14;
	}
	char bip[256], bsrv[16];
	/* on failure getnameinfo leaves bip/bsrv unspecified; never
	 * format them in that case */
	if(getnameinfo(addr, len,
			bip,  sizeof bip,
			bsrv, sizeof bsrv,
			NI_NUMERICHOST | NI_NUMERICSERV) != 0) {
		strcpy(d, "<bad endpoint>");
		return 14;
	}
	return (size_t)sprintf(d, "%s:%s", bip, bsrv);
}

/* Format an allowed-IP entry into `d` as "address/prefixlen", with the
 * address in numeric form.
 *
 * `d` is not bounds-checked: callers must supply a buffer large enough
 * for the result (callers in this file use 256 bytes).
 *
 * Family 0, non-IP families and addresses getnameinfo cannot render
 * produce a fixed "<...>" placeholder instead.
 *
 * Returns the number of characters written, excluding the NUL. */
size_t dumpAllowedIP(char* d, const wg_allowedip* aip) {
	/* getnameinfo wants a full sockaddr, so wrap the bare address */
	union {
		struct sockaddr_in  ip4;
		struct sockaddr_in6 ip6;
	} kinds;
	socklen_t len;
	switch(aip->family) {
		case AF_INET:
			kinds.ip4 = (struct sockaddr_in) {
				.sin_family = AF_INET,
				.sin_port = 0,
				.sin_addr = aip->ip4,
			};
			len = sizeof kinds.ip4;
			break;
		case AF_INET6:
			kinds.ip6 = (struct sockaddr_in6) {
				.sin6_family = AF_INET6,
				.sin6_port = 0,
				.sin6_flowinfo = 0,
				.sin6_scope_id = 0,
				.sin6_addr = aip->ip6,
			};
			len = sizeof kinds.ip6;
			break;
		case 0: strcpy(d, "<no IP>"); return 7;
		default: strcpy(d, "<non-IP address>"); return 16;
	}
	/* bsrv only has to hold the rendering of port 0 ("0") */
	char bip[256], bsrv[2];
	/* on failure getnameinfo leaves bip unspecified; never format it */
	if(getnameinfo((const void*)&kinds, len,
			bip,  sizeof bip,
			bsrv, sizeof bsrv,
			NI_NUMERICHOST | NI_NUMERICSERV) != 0) {
		strcpy(d, "<bad IP>");
		return 8;
	}
	return (size_t)sprintf(d, "%s/%u", bip, (unsigned)aip->cidr);
}

/* Decide whether two allowed-IP entries name the same network.
 *
 * Both must share the address family and prefix length. For AF_INET
 * and AF_INET6 the address bytes must also match. Entries of any other
 * family are considered equal once family and prefix length agree. */
bool compare_allowedip
(	const wg_allowedip* const a,
	const wg_allowedip* const b
) {
	if(a->family != b->family || a->cidr != b->cidr)
		return false;

	if(a->family == AF_INET)
		return a->ip4.s_addr == b->ip4.s_addr;

	if(a->family == AF_INET6)
		return memcmp(a->ip6.s6_addr, b->ip6.s6_addr,
		              sizeof(a->ip6.s6_addr)) == 0;

	return true;
}

/* Convert one raw inet column value from the database into a
 * single-host wg_allowedip: prefix length 32 for IPv4, 128 for IPv6.
 *
 * `data` is parsed by pqp_inet_read; its exact wire format is whatever
 * that helper expects. NOTE(review): any netmask carried by the DB
 * value is not consulted here; the prefix is always the full host
 * length.
 *
 * Aborts via _fatal on unparseable input or an unsupported family. */
wg_allowedip
inet_to_allowedip(const char* data) {
	pqp_sockstore store;
	if(!pqp_inet_read(data, &store.sock))
		_fatal("bad IP value in database");

	wg_allowedip out = {
		.family = store.sock.sa_family,
		.next_allowedip = null,
	};

	if(store.sock.sa_family == AF_INET) {
		out.cidr = 32;
		out.ip4 = store.sock_in.sin_addr;
	} else if(store.sock.sa_family == AF_INET6) {
		out.cidr = 128;
		out.ip6 = store.sock_in6.sin6_addr;
	} else {
		_fatal("unhandled address family");
	}
	return out;
}

/* Free a peer together with every node of its allowed-IP list.
 * The list walk mirrors the peer-freeing loop in
 * ext/wglib/wireguard.c. */
void wgd_free_peer(wg_peer* peer) {
	wg_allowedip* cur = peer->first_allowedip;
	while(cur) {
		/* fetch the link before the node holding it is released */
		wg_allowedip* following = cur->next_allowedip;
		free(cur);
		cur = following;
	}
	free(peer);
}

/* linked list manipulation routines */


/* Unlinking helpers for the wireguard library's singly linked peer
 * and allowed-IP lists. Currently compiled out: syncauth() removes
 * peers by setting WGPEER_REMOVE_ME and drops allowed IPs through
 * wgd_peer_drop_allowedip instead. */
#if 0
/* Unlink `peer` from `dev`'s peer list, keeping first_peer/last_peer
 * consistent, then free it with wgd_free_peer. Aborts via _fatal if
 * its predecessor cannot be found. */
void wgd_drop_peer(wg_device* dev, wg_peer* peer) {
	if(dev -> first_peer == peer) {
		if(dev -> last_peer == peer) {
			dev -> first_peer = dev -> last_peer = null;
		} else {
			dev -> first_peer = peer -> next_peer;
		}
	} else {
		wg_peer* p;
		if(dev -> last_peer == peer) {
			wg_for_each_peer(dev, p) {
				if(p->next_peer == peer) {
					dev -> last_peer = p;
					p->next_peer = null;
					goto found1;
				}
			}
			_fatal("BUG in last peer deletion routine");
			found1 :;
		} else /* in the middle */ {
			wg_for_each_peer(dev, p) {
				if(p->next_peer == peer) {
					p->next_peer = peer -> next_peer;
					goto found2;
				}
			}
			_fatal("BUG in peer deletion routine");
			found2 :;
		}
	}
	wgd_free_peer(peer);
}
/* Unlink `ip` from `peer`'s allowed-IP list, keeping
 * first_allowedip/last_allowedip consistent, then free it. Aborts via
 * _fatal if its predecessor cannot be found. */
void wgd_peer_drop_ip(wg_peer* peer, wg_allowedip* ip) {
	if(peer -> first_allowedip == ip) {
		if(peer -> last_allowedip == ip) {
			peer -> first_allowedip = peer -> last_allowedip = null;
		} else {
			/* the new head is the removed node's successor
			 * (wg_peer itself has no next_allowedip field) */
			peer -> first_allowedip = ip -> next_allowedip;
		}
	} else {
		wg_allowedip* a;
		if(peer -> last_allowedip == ip) {
			wg_for_each_allowedip(peer, a) {
				if(a->next_allowedip == ip) {
					peer -> last_allowedip = a;
					a->next_allowedip = null;
					goto found1;
				}
			}
			_fatal("BUG in last aIP deletion routine");
			found1 :;
		} else /* in the middle */ {
			wg_for_each_allowedip(peer, a) {
				if(a->next_allowedip == ip) {
					a->next_allowedip = ip -> next_allowedip;
					goto found2;
				}
			}
			_fatal("BUG in aIP deletion routine");
			found2 :;
		}
	}
	free(ip);
}
#endif

void syncauth(PGconn* db, const char* wgdev) {
	wg_device* wg;
	if (wg_get_device(&wg, wgdev))
		_fatal("no wireguard device by that name");

	bool dirty = false;
	size_t peerc = 0;
	{ wg_peer* p; wg_for_each_peer(wg, p) ++ peerc; };

	bool valid_peers [peerc];
	_zero(valid_peers);

	PGresult* rows = PQexecPrepared(db, "get_hosts",
			0, null, null, null, 1);
	if(!(rows && PQresultStatus(rows) == PGRES_TUPLES_OK))
		_fatal(PQerrorMessage(db));

	size_t rowc = PQntuples(rows);
	for(size_t i = 0; i < rowc; ++i) {
		const char* key_b64 = PQgetvalue(rows, i, 0);

		const char* aryraw = PQgetvalue(rows, i, 1);
		pqp_array* ips = pqp_array_read(aryraw);
		_dbgf("DB has peer %s", key_b64);
		if(ips->ty != pq_array_inet)
			_fatal("incorrect array type returned from DB");

		wg_key key;
		if (wg_key_from_base64(key, key_b64) < 0) {
			_warnf("invalid key in database: %s", key_b64);
			continue;
		}

		wg_peer* found = null;
		{ size_t j=0; wg_peer* p; wg_for_each_peer(wg, p) {
			if(memcmp(p->public_key, key, sizeof key) == 0) {
				_dbgf("validating peer %s", key_b64);
				valid_peers[j] = true;
				found = p;
				break;
			}
		++j;}}

		if (found) {
			/* compare and update IPs if necessary */

			size_t wgIPc = 0;
			{ wg_allowedip* a; wg_for_each_allowedip(found, a) ++wgIPc; };
			bool goodIPs [wgIPc]; _zero(goodIPs);
			/* extant IPs that are not marked good by the
			 * end of the following loop must be deleted
			 * from memory */
			size_t goodIPc = 0;
			for (size_t j = 0; j < ips -> sz; ++j) {
				char inetstr[256];
				wg_allowedip aip = inet_to_allowedip(ips -> elts[j].data);
				dumpAllowedIP(inetstr, &aip);
				_dbgf("IP PG%zu :: %s", j, inetstr);

				size_t l = 0;
				wg_allowedip* wgip;
				bool foundIP = false;
				wg_for_each_allowedip(found, wgip) {
					if (compare_allowedip(&aip, wgip)) {
						++goodIPc; goodIPs[l] = true;
						foundIP = true;
					}
				++l;}
				
				if(foundIP == false) {
					/* this IP hasn't been loaded into the
					 * kernel yet; upload it now */
					_infof("inserting IP PG%zu %s", j, inetstr);
					found -> flags |= WGPEER_REPLACE_ALLOWEDIPS;
					wg_allowedip* nip = wgd_peer_new_allowedip(found);
					memcpy(nip, &aip, sizeof aip);
					dirty = true;
				}
			}

			if(goodIPc < wgIPc) {
				size_t l = 0;
				wg_allowedip* wgip;
				wg_for_each_allowedip(found, wgip) {
					char inetstr[256];
					dumpAllowedIP(inetstr, wgip);
					_dbgf("IP WG%zu :: %s", l, inetstr);
					if(l<goodIPc && !goodIPs[l]) {
						/* this IP is stale, delete it */
						_infof("deleting IP WG%zu %s", l, inetstr);
						wgd_peer_drop_allowedip(found, wgip);
						found -> flags |= WGPEER_REPLACE_ALLOWEDIPS;
						dirty = true;
					}
				++l;}
			}
		} else {
			_infof("inserting key %s", key_b64);
			dirty = true;
			/* install new peer */
			wg_peer* np = wgd_new_peer(wg);
			np -> flags = WGPEER_HAS_PUBLIC_KEY;
			memcpy(np -> public_key, key, sizeof key);

			for (size_t j = 0; j < ips -> sz; ++j) {
				char inetstr[256];
				wg_allowedip aip = inet_to_allowedip(ips -> elts[j].data);
				dumpAllowedIP(inetstr, &aip);
				_dbgf("new IP %zu :: %s", j, inetstr);
				wg_allowedip* nip = wgd_peer_new_allowedip(np);
				memcpy(nip, &aip, sizeof aip);
			}
		}

		free(ips);
	}
	{ size_t i=0; wg_peer* p; wg_for_each_peer(wg, p) {
		if(i<peerc && valid_peers[i] == false) {
			char b64 [128];
			wg_key_to_base64(b64, p->public_key);
			_infof("dropping peer %s", b64);
			//wgd_drop_peer(wg, p);
			p -> flags |= WGPEER_REMOVE_ME;
			dirty = true;
		}
	++i;}}

	_dbg("final peer list:");
	{ size_t j=0; wg_peer* p; wg_for_each_peer(wg, p) {
		char b64 [128];
		wg_key_to_base64(b64, p->public_key);
		_dbgf("P%zu :: %s%s", j, b64,
			p->flags & WGPEER_REMOVE_ME          ? " [DELETE]" :
			p->flags & WGPEER_REPLACE_ALLOWEDIPS ? " [CHGIP]" : "");
	++j;}}
	
	dirty = true;
	if(dirty) {
		int e = wg_set_device(wg);
		if(e != 0) 
			_fatalf("could not set wg device (error %i)", -e);
	}

	PQclear(rows);
}

/* Event loop for the wait/fork modes.
 *
 * LISTENs on the sync_vpn and sync_priv channels and runs syncauth()
 * (at most once per wakeup) when a notification arrives. On Linux,
 * SIGHUP also triggers a sync, and SIGTERM/SIGINT end the loop. A lost
 * DB connection is fatal. */
void daemonmain(PGconn* db, const char* wgdev) {
	PGresult* subscribe = PQexec(db,
		"listen sync_vpn;"
		"listen sync_priv;");
	if (PQresultStatus(subscribe) != PGRES_COMMAND_OK)
		_warn("could not subscribe to DB notification channels");
	PQclear(subscribe);

	int pqfd = PQsocket(db);
	if (pqfd < 0)
		_fatal("lost DB connection; terminating");
#if __linux__
	sigset_t sigs;
	sigemptyset(&sigs);
	sigaddset(&sigs, SIGHUP);
	sigaddset(&sigs, SIGTERM);
	sigaddset(&sigs, SIGINT);
	sigprocmask(SIG_BLOCK, &sigs, null);
	int sigfd = signalfd(-1, &sigs, SFD_CLOEXEC);
	/* the signals are now blocked; without a signalfd they would
	 * never be handled, so SIGTERM/SIGINT could not stop us */
	if (sigfd < 0)
		_fatal("could not create signalfd");
#endif

	struct pollfd polls[] = {
		{ .fd = pqfd, .events = POLLIN, .revents = 0 },
#if __linux__
		{ .fd = sigfd, .events = POLLIN, .revents = 0 },
#endif
	};

	bool running = true;
	while (running) {
		if (poll(polls, _sz(polls), -1) <= 0)
			continue; /* e.g. interrupted: just wait again */

		bool didSync = false;

		/* revents is a bit mask: POLLIN may arrive together with
		 * POLLHUP/POLLERR, so test bits rather than exact values */
		short dbev = polls[0].revents;
		if (dbev & POLLIN) {
			if (!PQconsumeInput(db))
				_fatal("lost DB connection; terminating");
			PGnotify* n;
			while ((n = PQnotifies(db)) != null) {
				bool ours = strcmp(n->relname, "sync_vpn") == 0
				         || strcmp(n->relname, "sync_priv") == 0;
				/* each notification is heap-allocated by libpq */
				PQfreemem(n);
				if (ours && !didSync) {
					syncauth(db, wgdev);
					didSync = true;
				}
			}
		} else if (dbev & (POLLHUP | POLLERR | POLLNVAL)) {
			_fatal("lost DB connection; terminating");
		}

#if __linux__
		if (polls[1].revents & POLLIN) {
			struct signalfd_siginfo si;
			/* ignore short reads rather than act on garbage */
			if (read(sigfd, &si, sizeof si) == (ssize_t)sizeof si) {
				if (si.ssi_signo == SIGHUP) {
					if (!didSync) {
						syncauth(db, wgdev);
						didSync = true;
					}
				} else if (si.ssi_signo == SIGTERM || si.ssi_signo == SIGINT) {
					running = false;
				}
			}
		}
#endif
	}

	_info("shutting down");
#if __linux__
	close(sigfd);
#endif
}

/* Entry point: wgsync <mode> <device>.
 *
 * Modes:
 *   sync      - reconcile once and exit
 *   wait      - run the notification loop in the foreground
 *   syncwait  - reconcile once, then wait
 *   fork      - daemonize, then wait
 *   syncfork  - reconcile once, daemonize, then wait
 *
 * The connection string comes from the wgsync_conn environment
 * variable. */
int main(int argc, char** argv) {
	setvbuf(stderr, null, _IONBF, 0);
	if (argc < 3) {
		_fatal("usage: wgsync (sync|wait|syncwait|fork|syncfork) <device>");
	}

	const char* arg_mode = argv[1];
	const char* arg_devname = argv[2];

	/* mostly for the sake of debugging, allow the
	 * binary to be run from sudo without losing
	 * postgres peer credentials */
	if(geteuid() == 0) {
		const char* suid = getenv("SUDO_UID");
		const char* susr = getenv("SUDO_USER");
		if(suid && seteuid(atoi(suid)) != 0)
			_warn("could not switch to SUDO_UID; connecting as root");
		if(susr) setenv("USER", susr, 1);
	}

	char* connstr = getenv("wgsync_conn");
	if(connstr == null) _fatal("no connection string supplied");
	PGconn* db = PQconnectdb(connstr);
	if(PQstatus(db) != CONNECTION_OK)
		_fatal(PQerrorMessage(db));

	/* one row per host: its key reference and all its
	 * non-null v4/v6 wireguard addresses */
	PGresult* q_get_hosts = PQprepare(db, "get_hosts",
		"select h.ref, array_remove(array_agg(wgv4::inet)"
		                        "|| array_agg(wgv6::inet), null) "
			"from ns, hostref h "
			"where ns.host = h.host and kind = 'pubkey' "
			" group by h.host, h.ref;", 0, null);
	if(!(q_get_hosts && PQresultStatus(q_get_hosts) == PGRES_COMMAND_OK))
		_fatal(PQerrorMessage(db));
	PQclear(q_get_hosts);

	/* we're going to interact with WG now;
	 * get our superpowers back if we lost them */
	{uid_t svuid;
	getresuid(null, null, &svuid);
	if (svuid == 0 && setuid(0) != 0)
		_warn("could not regain root privileges");}

	if(strcmp(arg_mode, "sync") == 0) {
		syncauth(db, arg_devname);
	} else if(strcmp(arg_mode, "wait")     == 0 ||
	          strcmp(arg_mode, "syncwait") == 0 ||
	          strcmp(arg_mode, "fork")     == 0 ||
	          strcmp(arg_mode, "syncfork") == 0) {

		if(strncmp(arg_mode, "sync", 4) == 0)
			syncauth(db, arg_devname);

		/* maybe background daemon */
		if(strcmp(arg_mode, "fork")     == 0 ||
		   strcmp(arg_mode, "syncfork") == 0) {
			if (daemon(1,1) == -1)
				_fatal("cannot daemonize");
		}

		daemonmain(db, arg_devname);
	} else {
		_fatal("valid modes are sync, wait, syncwait, fork, and syncfork");
	}
	/* other possibilities: a mode that generates an eventfd
	 * and provides it on fd4 to a subordinate process, or
	 * sends it with SCM_RIGHTS */

	PQfinish(db);
	return 0;
}