diff --git a/src/smartdns.c b/src/smartdns.c index dd1132c..8838fe0 100644 --- a/src/smartdns.c +++ b/src/smartdns.c @@ -47,6 +47,7 @@ #include #include #include +#include #include #include @@ -153,6 +154,7 @@ static void _help(void) " -f run foreground.\n" " -c [conf] config file.\n" " -p [pid] pid file path, '-' means don't create pid file.\n" + " -R restart smartdns when crash.\n" " -S ignore segment fault signal.\n" " -x verbose screen.\n" " -v display version.\n" @@ -771,6 +773,56 @@ static void _smartdns_early_log(struct tlog_loginfo *loginfo, const char *format syslog(sys_log_level, "%s", log_buf); } +static int _smartdns_child_pid = 0; + +static void _smartdns_run_as_init_sig(int sig) +{ + if (_smartdns_child_pid > 0) { + kill(_smartdns_child_pid, SIGTERM); + waitpid(_smartdns_child_pid, NULL, 0); + } + + _exit(0); +} + +static int _smartdns_run_as_init(int restart_when_crash) +{ + pid_t pid; + + if (restart_when_crash == 0) { + return 0; + } + + pid = getpid(); + setpgid(pid, pid); + +restart: + pid = fork(); + if (pid < 0) { + fprintf(stderr, "fork failed, %s\n", strerror(errno)); + return -1; + } else if (pid == 0) { + return 0; + } + + _smartdns_child_pid = pid; + + signal(SIGTERM, _smartdns_run_as_init_sig); + + while (true) { + pid = waitpid(-1, NULL, 0); + if (pid == _smartdns_child_pid) { + goto restart; + } + + if (pid < 0) { + sleep(1); + } + } + + return -1; +} + #ifdef TEST static smartdns_post_func _smartdns_post = NULL; @@ -815,6 +867,7 @@ int main(int argc, char *argv[]) char pid_file[MAX_LINE_LEN]; int is_pid_file_set = 0; int signal_ignore = 0; + int restart_when_crash = getpid() == 1 ? 1 : 0; sigset_t empty_sigblock; struct stat sb; @@ -834,7 +887,7 @@ int main(int argc, char *argv[]) sigprocmask(SIG_SETMASK, &empty_sigblock, NULL); smartdns_close_allfds(); - while ((opt = getopt_long(argc, argv, "fhc:p:SvxN:", long_options, NULL)) != -1) { + while ((opt = getopt_long(argc, argv, "fhc:p:SvxN:R", long_options, NULL)) != -1) { switch (opt) { case 'f': is_run_as_daemon = 0; @@ -850,6 +903,9 @@ int main(int argc, char *argv[]) is_pid_file_set = 1; } break; + case 'R': + restart_when_crash = 1; + break; case 'S': signal_ignore = 1; break; @@ -876,6 +932,10 @@ int main(int argc, char *argv[]) } } + if (_smartdns_run_as_init(restart_when_crash) != 0) { + return 1; + } + srand(time(NULL)); tlog_reg_early_printf_callback(_smartdns_early_log);