util  Diff

Differences From Artifact [da84aa2670]:

To Artifact [804d360053]:


     7      7   #include <string.h>
     8      8   
     9      9   /* posix */
    10     10   #include <netinet/in.h>
    11     11   #include <unistd.h>
    12     12   #include <sys/socket.h>
    13     13   #include <netdb.h>
           14  +#include <poll.h>
           15  +
           16  +#if __linux__
           17  +#	include <sys/signalfd.h>
           18  +#	include <signal.h>
           19  +#endif
    14     20   
    15     21   /* libs */
    16     22   #include <wireguard.h>
    17     23   #include "wglist.h"
    18     24   	/* wireguard uses messy linked lists but doesn't
    19     25   	 * provide any routines for manipulating them;
    20     26   	 * wglist.h fills in the gap */
................................................................................
   340    346   		int e = wg_set_device(wg);
   341    347   		if(e != 0) 
   342    348   			_fatalf("could not set wg device (error %i)", -e);
   343    349   	}
   344    350   
   345    351   	PQclear(rows);
   346    352   }
          353  +
          354  +void daemonmain(PGconn* db, const char* wgdev) {
          355  +	PGresult* subscribe = PQexec(db,
          356  +		"listen sync_vpn;"
          357  +		"listen sync_priv;");
          358  +	if (PQresultStatus(subscribe) != PGRES_COMMAND_OK)
          359  +		_warn("could not subscribe to DB notification channels");
          360  +	PQclear(subscribe);
          361  +
          362  +	int pqfd = PQsocket(db);
          363  +#if __linux__
          364  +	sigset_t sigs;
          365  +	sigemptyset(&sigs);
          366  +	sigaddset(&sigs, SIGHUP);
          367  +	sigaddset(&sigs, SIGTERM);
          368  +	sigaddset(&sigs, SIGINT);
          369  +	sigprocmask(SIG_BLOCK, &sigs, null);
          370  +	int sigfd = signalfd(-1, &sigs, SFD_CLOEXEC);
          371  +#endif
          372  +
          373  +	struct pollfd polls[] = {
          374  +		{ .fd = pqfd, .events = POLLIN, .revents = 0 },
          375  +#if __linux__
          376  +		{ .fd = sigfd, .events = POLLIN, .revents = 0 },
          377  +#endif
          378  +	};
          379  +
          380  +	for (;;) {
          381  +		int p = poll(polls, _sz(polls), -1);
          382  +		if (p > 0) {
          383  +			bool didSync = false;
          384  +			switch (polls[0].revents) {
          385  +				case 0: break;
          386  +				case POLLHUP:
          387  +					_fatal("lost DB connection; terminating");
          388  +				case POLLIN: {
          389  +					PQconsumeInput(db);
          390  +					for (;;) {
          391  +						PGnotify* n = PQnotifies(db);
          392  +						if(n == null) break;
          393  +						if(strcmp(n->relname, "sync_vpn") == 0
          394  +						|| strcmp(n->relname, "sync_priv") == 0) {
          395  +							if(!didSync) {
          396  +								syncauth(db, wgdev);
          397  +								didSync = true;
          398  +							}
          399  +						}
          400  +					}
          401  +				}
          402  +			}
          403  +#if __linux__
          404  +			switch (polls[1].revents) {
          405  +				case 0: break;
          406  +				case POLLIN: {
          407  +					struct signalfd_siginfo si;
          408  +					read(sigfd, &si, sizeof si);
          409  +
          410  +					if(si.ssi_signo == SIGHUP && !didSync) {
          411  +						syncauth(db, wgdev);
          412  +						didSync = true;
          413  +					} else if (si.ssi_signo == SIGTERM || si.ssi_signo == SIGINT) {
          414  +						goto poll_end;
          415  +					}
          416  +				};
          417  +			}
          418  +#endif
          419  +		}
          420  +	}
          421  +
          422  +	poll_end :;
          423  +
          424  +	_info("shutting down");
          425  +#if __linux__
          426  +	close(sigfd);
          427  +#endif
          428  +}
   347    429   
   348    430   int main(int argc, char** argv) {
   349    431   	setvbuf(stderr, null, _IONBF, 0);
   350    432   	if (argc < 3) {
   351    433   		_fatal("missing device name");
   352    434   	}
   353    435   
................................................................................
   386    468   	 * get our superpowers back if we lost them */
   387    469   	{uid_t svuid;
   388    470   	getresuid(null, null, &svuid);
   389    471   	if (svuid == 0) setuid(0);}
   390    472   
   391    473   	if(strcmp(arg_mode, "sync") == 0) {
   392    474   		syncauth(db, arg_devname);
   393         -	} else if(strcmp(arg_mode, "wait") == 0) {
   394         -		/* foreground daemon */
   395         -	} else if(strcmp(arg_mode, "fork") == 0) {
   396         -		/* background daemon */
          475  +	} else if(strcmp(arg_mode, "wait")     == 0 ||
          476  +	          strcmp(arg_mode, "syncwait") == 0 ||
          477  +	          strcmp(arg_mode, "fork")     == 0 ||
          478  +	          strcmp(arg_mode, "syncfork") == 0) {
          479  +
          480  +		if(strncmp(arg_mode, "sync", 4) == 0)
          481  +			syncauth(db, arg_devname);
          482  +
          483  +		/* maybe background daemon */
          484  +		if(strcmp(arg_mode, "fork")     == 0 ||
          485  +	       strcmp(arg_mode, "syncfork") == 0) {
          486  +			if (daemon(1,1) == -1)
          487  +				_fatal("cannot daemonize");
          488  +		}
          489  +
          490  +		daemonmain(db, arg_devname);
   397    491   	} else {
   398         -		_fatal("valid modes are sync, wait, and fork");
          492  +		_fatal("valid modes are sync, wait, syncwait, fork, and syncfork");
   399    493   	}
   400    494   	/* other possibilities: a mode that generates an eventfd
   401    495   	 * and provides it on fd4 to a subordinate process, or
   402    496   	 * sends it with SCM_RIGHTS */
   403    497   
   404    498   	PQfinish(db);
   405    499   	return 0;
   406    500   }