util  wgsync.c at [81321a2c01]

File wgsync/src/wgsync.c artifact 40df359611 part of check-in 81321a2c01


#include "def.h"
#include "pqp.h"

/* libc */
#include <stdlib.h>
#include <stdio.h>
#include <string.h>

/* posix */
#include <netinet/in.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netdb.h>

/* libs */
#include <wireguard.h>

#include <libpq-fe.h>


/* Write a human-readable form of a WireGuard endpoint into d and return
 * the number of characters written (excluding the terminating NUL).
 * IPv4 endpoints are written as "addr:port" and IPv6 ones as
 * "[addr]:port" (so the port is not confused with the address).
 * Placeholder text is written for an unset (family 0) or non-IP
 * endpoint, and when getnameinfo cannot format the address.
 * d is not bounds-checked: the caller must provide enough room for a
 * numeric address, brackets, ':' and port — NOTE(review). */
size_t dumpEndpoint(char* d, const wg_endpoint* const e) {
	const struct sockaddr* addr;
	socklen_t len;
	switch (e->addr.sa_family) {
		case AF_INET:  addr = (const void*)&e->addr4; len = sizeof e->addr4; break;
		case AF_INET6: addr = (const void*)&e->addr6; len = sizeof e->addr6; break;
		case 0:  strcpy(d, "<endpoint unset>"); return strlen(d);
		default: strcpy(d, "<bad endpoint>");   return strlen(d);
	}

	char bip[256], bsrv[16];
	/* on failure getnameinfo leaves bip/bsrv unwritten; formatting them
	 * would read uninitialised memory */
	if (getnameinfo(addr, len,
			bip,  sizeof bip,
			bsrv, sizeof bsrv,
			NI_NUMERICHOST | NI_NUMERICSERV) != 0) {
		strcpy(d, "<unprintable endpoint>");
		return strlen(d);
	}

	int n = (e->addr.sa_family == AF_INET6)
		? sprintf(d, "[%s]:%s", bip, bsrv)
		: sprintf(d, "%s:%s", bip, bsrv);
	/* a negative (error) result must not become a huge size_t */
	return n < 0 ? 0 : (size_t)n;
}

/* Write an allowed IP in CIDR notation ("addr/prefix") into d and return
 * the number of characters written (excluding the terminating NUL).
 * Placeholder text is written for family 0, for non-IP families, and
 * when getnameinfo cannot format the address.  d is not bounds-checked;
 * the callers in this file pass 256-byte buffers. */
size_t dumpAllowedIP(char* d, const wg_allowedip* aip) {
	/* getnameinfo wants a sockaddr, so wrap the bare address in one */
	union {
		struct sockaddr_in  ip4;
		struct sockaddr_in6 ip6;
	} kinds;
	socklen_t len;
	switch (aip->family) {
		case AF_INET:
			kinds.ip4 = (struct sockaddr_in) {
				.sin_family = AF_INET,
				.sin_port = 0,
				.sin_addr = aip->ip4,
			};
			len = sizeof kinds.ip4;
			break;
		case AF_INET6:
			kinds.ip6 = (struct sockaddr_in6) {
				.sin6_family = AF_INET6,
				.sin6_port = 0,
				.sin6_flowinfo = 0,
				.sin6_scope_id = 0,
				.sin6_addr = aip->ip6,
			};
			len = sizeof kinds.ip6;
			break;
		case 0:  strcpy(d, "<no IP>"); return strlen(d);
		default: strcpy(d, "<non-IP address>"); return strlen(d);
	}

	char bip[256], bsrv[2];
	/* on failure getnameinfo leaves bip unwritten; formatting it would
	 * read uninitialised memory */
	if (getnameinfo((const struct sockaddr*)&kinds, len,
			bip,  sizeof bip,
			bsrv, sizeof bsrv,
			NI_NUMERICHOST | NI_NUMERICSERV) != 0) {
		strcpy(d, "<unprintable IP>");
		return strlen(d);
	}

	int n = sprintf(d, "%s/%u", bip, (unsigned)aip->cidr);
	/* a negative (error) result must not become a huge size_t */
	return n < 0 ? 0 : (size_t)n;
}

/* Equality test for two allowed IPs.  They match when family and prefix
 * length agree and, for AF_INET / AF_INET6, the address bytes agree.
 * For any other family only family and prefix length are compared. */
bool compare_allowedip
(	const wg_allowedip* const a,
	const wg_allowedip* const b
) {
	if (a->family != b->family || a->cidr != b->cidr)
		return false;

	if (a->family == AF_INET)
		return a->ip4.s_addr == b->ip4.s_addr;

	if (a->family == AF_INET6)
		return memcmp(a->ip6.s6_addr, b->ip6.s6_addr,
		              sizeof a->ip6.s6_addr) == 0;

	/* no address payload is compared for other families */
	return true;
}

/* Convert one inet element of a get_hosts result array into a
 * single-host wg_allowedip: /32 for IPv4, /128 for IPv6.
 * Parsing is delegated to pqp_inet_read; an unparseable value or a
 * non-IP address family aborts through _fatal. */
wg_allowedip
inet_to_allowedip(const char* data) {
	pqp_sockstore addr;
	if (!pqp_inet_read(data, &addr.sock))
		_fatal("bad IP value in database");

	wg_allowedip out = {
		.family = addr.sock.sa_family,
		.next_allowedip = null,
	};

	if (addr.sock.sa_family == AF_INET) {
		out.ip4  = addr.sock_in.sin_addr;
		out.cidr = 32;
	} else if (addr.sock.sa_family == AF_INET6) {
		out.ip6  = addr.sock_in6.sin6_addr;
		out.cidr = 128;
	} else {
		_fatal("unhandled address family");
	}
	return out;
}

/* Free a heap-allocated peer: first every node of its allowed-IP list,
 * then the peer itself.  Same cleanup as the per-peer loop in
 * ext/wglib/wireguard.c. */
void wgd_free_peer(wg_peer* peer) {
	wg_allowedip* cur = peer->first_allowedip;
	while (cur) {
		/* grab the link before the node is released */
		wg_allowedip* next = cur->next_allowedip;
		free(cur);
		cur = next;
	}
	free(peer);
}

/* linked list manipulation routines */

/* The generic singly-linked-list helpers in list.h are instantiated twice
 * below, each time configured by the _ll_* macros.  list.h is not visible
 * here.  Judging by the values: _ll_box/_ll_obj are the container and
 * element types, _ll_rec is the field-name stem (first_peer/next_peer...),
 * _ll_iter is the matching wireguard.h iteration macro and _ll_ns is the
 * prefix of the generated function names — TODO confirm against list.h.
 * NOTE(review): list.h presumably #undefs these macros so that they can be
 * redefined for the second instantiation, and presumably generates
 * wgd_drop_peer (called by syncauth) — verify. */

/* peers of a wg_device */
#define _ll_rec peer
#define _ll_box wg_device
#define _ll_obj wg_peer
#define _ll_iter wg_for_each_peer
#define _ll_ns wgd
#include "list.h"

/* allowed IPs of a wg_peer */
#define _ll_rec allowedip
#define _ll_box wg_peer
#define _ll_obj wg_allowedip
#define _ll_iter wg_for_each_allowedip
#define _ll_ns wgd_peer
#include "list.h"

/* Disabled hand-written unlink routines, kept for reference.
 * NOTE(review): the list.h instantiations presumably provide the live
 * equivalents (syncauth calls wgd_drop_peer) — confirm before re-enabling,
 * or the names will clash.
 * Each routine unlinks one element from its list, fixing up first_/last_
 * pointers, and then frees it (wgd_free_peer / free). */
#if 0
void wgd_drop_peer(wg_device* dev, wg_peer* peer) {
	if(dev -> first_peer == peer) {
		if(dev -> last_peer == peer) {
			dev -> first_peer = dev -> last_peer = null;
		} else {
			dev -> first_peer = peer -> next_peer;
		}
	} else {
		wg_peer* p;
		if(dev -> last_peer == peer) {
			/* find the predecessor, which becomes the new tail */
			wg_for_each_peer(dev, p) {
				if(p->next_peer == peer) {
					dev -> last_peer = p;
					p->next_peer = null;
					goto found1;
				}
			}
			_fatal("BUG in last peer deletion routine");
			found1 :;
		} else /* in the middle */ {
			wg_for_each_peer(dev, p) {
				if(p->next_peer == peer) {
					p->next_peer = peer -> next_peer;
					goto found2;
				}
			}
			_fatal("BUG in peer deletion routine");
			found2 :;
		}
	}
	wgd_free_peer(peer);
}
void wgd_peer_drop_ip(wg_peer* peer, wg_allowedip* ip) {
	if(peer -> first_allowedip == ip) {
		if(peer -> last_allowedip == ip) {
			peer -> first_allowedip = peer -> last_allowedip = null;
		} else {
			/* the new head is the removed node's successor (this
			 * previously read a nonexistent peer->next_allowedip) */
			peer -> first_allowedip = ip -> next_allowedip;
		}
	} else {
		wg_allowedip* a;
		if(peer -> last_allowedip == ip) {
			/* find the predecessor, which becomes the new tail */
			wg_for_each_allowedip(peer, a) {
				if(a->next_allowedip == ip) {
					peer -> last_allowedip = a;
					a->next_allowedip = null;
					goto found1;
				}
			}
			_fatal("BUG in last aIP deletion routine");
			found1 :;
		} else /* in the middle */ {
			wg_for_each_allowedip(peer, a) {
				if(a->next_allowedip == ip) {
					a->next_allowedip = ip -> next_allowedip;
					goto found2;
				}
			}
			_fatal("BUG in aIP deletion routine");
			found2 :;
		}
	}
	free(ip);
}
#endif

void syncauth(PGconn* db, const char* wgdev) {
	wg_device* wg;
	if (wg_get_device(&wg, wgdev))
		_fatal("no wireguard device by that name");

	bool dirty = false;
	size_t peerc = 0;
	{ wg_peer* p; wg_for_each_peer(wg, p) ++ peerc; };

	bool valid_peers [peerc];
	_zero(valid_peers);

	PGresult* rows = PQexecPrepared(db, "get_hosts",
			0, null, null, null, 1);
	if(!(rows && PQresultStatus(rows) == PGRES_TUPLES_OK))
		_fatal(PQerrorMessage(db));

	size_t rowc = PQntuples(rows);
	for(size_t i = 0; i < rowc; ++i) {
		const char* key_b64 = PQgetvalue(rows, i, 0);

		const char* aryraw = PQgetvalue(rows, i, 1);
		pqp_array* ips = pqp_array_read(aryraw);
		_dbgf("DB has peer %s", key_b64);
		if(ips->ty != pq_array_inet)
			_fatal("incorrect array type returned from DB");

		wg_key key;
		if (wg_key_from_base64(key, key_b64) < 0) {
			_warnf("invalid key in database: %s", key_b64);
			continue;
		}

		wg_peer* found = null;
		{ size_t j=0; wg_peer* p; wg_for_each_peer(wg, p) {
			if(memcmp(p->public_key, key, sizeof key) == 0) {
				_dbgf("validating peer %s", key_b64);
				valid_peers[j] = true;
				found = p;
				break;
			}
		++j;}}

		if (found) {
			/* compare and update IPs if necessary */
			bool goodIPs [ips -> sz]; _zero(goodIPs);
			/* extant IPs that are not marked good by the
			 * end of the following loop must be deleted
			 * from memory */
			size_t goodIPc = 0;
			for (size_t j = 0; j < ips -> sz; ++j) {
				char inetstr[256];
				wg_allowedip aip = inet_to_allowedip(ips -> elts[j].data);
				dumpAllowedIP(inetstr, &aip);
				_dbgf("IP PG%zu :: %s", j, inetstr);

				size_t l = 0;
				wg_allowedip* wgip;
				bool foundIP = false;
				wg_for_each_allowedip(found, wgip) {
					if (compare_allowedip(&aip, wgip)) {
						++goodIPc; goodIPs[l] = true;
						foundIP = true;
					}
				++l;}
				
				if(!foundIP) {
					/* this IP hasn't been loaded into the
					 * kernel yet; upload it now */
					_infof("inserting IP PG%zu %s", j, inetstr);
					dirty = true;
				}
			}

			if(goodIPc < ips -> sz) {
				size_t l = 0;
				wg_allowedip* wgip;
				wg_for_each_allowedip(found, wgip) {
					char inetstr[256];
					dumpAllowedIP(inetstr, wgip);
					_dbgf("IP WG%zu :: %s", l, inetstr);
					if(!goodIPs[l]) {
						/* this IP is stale, delete it */
						_infof("deleting IP WG%zu %s", l, inetstr);
						dirty = true;
					}
				++l;}
			}
		} else {
			_infof("inserting key %s", key_b64);
			dirty = true;
			/* install new peer */
			for (size_t j = 0; j < ips -> sz; ++j) {
				char inetstr[256];
				wg_allowedip aip = inet_to_allowedip(ips -> elts[j].data);
				dumpAllowedIP(inetstr, &aip);
				_dbgf("new IP %zu :: %s", j, inetstr);
			}
		}

		free(ips);
	}
	{ size_t i=0; wg_peer* p; wg_for_each_peer(wg, p) {
		if(valid_peers[i] == false) {
			char b64 [128];
			wg_key_to_base64(b64, p->public_key);
			_infof("dropping peer %s", b64);
			wgd_drop_peer(wg, p);
			dirty = true;
		}
	++i;}}

	_dbg("final peer list:");
	{ size_t j=0; wg_peer* p; wg_for_each_peer(wg, p) {
		char b64 [128];
		wg_key_to_base64(b64, p->public_key);
		_dbgf("P%zu :: %s", j, b64);
	++j;}}
	
	if(dirty) wg_set_device(wg);

	PQclear(rows);
}

int main(int argc, char** argv) {
	setvbuf(stderr, null, _IONBF, 0);
	if (argc < 3) {
		_fatal("missing device name");
	}

	const char* arg_mode = argv[1];
	const char* arg_devname = argv[2];

	/* mostly for the sake of debugging, allow the
	 * binary to be run from sudo without losing
	 * postgres peer credentials */
	if(geteuid() == 0) {
		char* suid = getenv("SUDO_UID");
		char* susr = getenv("SUDO_USER");
		if(suid) seteuid(atoi(suid));
		if(susr) setenv("USER",getenv("SUDO_USER"), 1);
	}

	PGconn* db = PQconnectdb("dbname=domain");
	if(PQstatus(db) != CONNECTION_OK) 
		_fatal(PQerrorMessage(db));

	PGresult* q_get_hosts = PQprepare(db, "get_hosts",
		"select h.ref, array_remove(array_agg(wgv4::inet)"
		                        "|| array_agg(wgv6::inet), null)"
			"from ns, hostref h "
			"where ns.host = h.host and kind = 'pubkey' "
			" group by h.host, h.ref;", 0, null);
		/*"select ns.wgv4::inet, ns.wgv6::inet, h.ref from ns "
			"right join hostref h "
				"on h.host = ns.host "
			"where h.kind = 'pubkey';"*/
	if(!(q_get_hosts && PQresultStatus(q_get_hosts) == PGRES_COMMAND_OK))
		_fatal(PQerrorMessage(db));
	PQclear(q_get_hosts);

	/* we're going to interact with WG now;
	 * get our superpowers back if we lost them */
	{uid_t svuid;
	getresuid(null, null, &svuid);
	if (svuid == 0) setuid(0);}

	if(strcmp(arg_mode, "sync") == 0) {
		syncauth(db, arg_devname);
	} else if(strcmp(arg_mode, "wait") == 0) {
		/* foreground daemon */
	} else if(strcmp(arg_mode, "fork") == 0) {
		/* background daemon */
	} else {
		_fatal("valid modes are sync, wait, and fork");
	}
	/* other possibilities: a mode that generates an eventfd
	 * and provides it on fd4 to a subordinate process, or
	 * sends it with SCM_RIGHTS */

	PQfinish(db);
	return 0;
}