oops
[l2tpns.git] / nsctl.c
diff --git a/nsctl.c b/nsctl.c
index 8a8aee0..057edfa 100644 (file)
--- a/nsctl.c
+++ b/nsctl.c
+/* l2tpns plugin control */
+
 #include <stdio.h>
-#include <arpa/inet.h>
+#include <stdlib.h>
+#include <unistd.h>
 #include <errno.h>
-#include <fcntl.h>
-#include <malloc.h>
-#include <netdb.h>
 #include <string.h>
-#include <sys/time.h>
-#include <sys/types.h>
-#include <sys/socket.h>
-#include <netinet/in.h>
+#include <netdb.h>
 #include <signal.h>
-#include <stdarg.h>
-#include <stdlib.h>
-#include <unistd.h>
-#include <time.h>
+
+#include "l2tpns.h"
 #include "control.h"
 
-struct { char *command; int pkt_type; int params; } commands[] = {
-       { "load_plugin", PKT_LOAD_PLUGIN, 1 },
-       { "unload_plugin", PKT_UNLOAD_PLUGIN, 1 },
-       { "garden", PKT_GARDEN, 1 },
-       { "ungarden", PKT_UNGARDEN, 1 },
+struct {
+    char *command;
+    char *usage;
+    int action;
+} builtins[] = {
+    { "load_plugin", " PLUGIN                          Load named plugin",             NSCTL_REQ_LOAD },
+    { "unload_plugin", " PLUGIN                        Unload named plugin",           NSCTL_REQ_UNLOAD },
+    { "help", "                                        List available commands",       NSCTL_REQ_HELP },
+    { 0 }
 };
 
-char *dest_host = NULL;
-unsigned int dest_port = 1702;
-int udpfd;
+static int debug = 0;
+static int timeout = 2; // 2 seconds
+static char *me;
+
+#define USAGE() fprintf(stderr, "Usage: %s [-d] [-h HOST[:PORT]] [-t TIMEOUT] COMMAND [ARG ...]\n", me)
+
+static struct nsctl *request(char *host, int port, int type, int argc, char *argv[]);
 
 int main(int argc, char *argv[])
 {
-       int len = 0;
-       int dest_ip = 0;
-       int pkt_type = 0;
-       char *packet = NULL;
-       int i;
+    int req_type = 0;
+    char *host = 0;
+    int port;
+    int i;
+    char *p;
+    struct nsctl *res;
+
+    if ((p = strrchr((me = argv[0]), '/')))
+       me = p + 1;
+
+    opterr = 0;
+    while ((i = getopt(argc, argv, "dh:t:")) != -1)
+       switch (i)
+       {
+       case 'd':
+           debug++;
+           break;
 
-       setbuf(stdout, NULL);
+       case 'h':
+           host = optarg;
+           break;
 
-       if (argc < 3)
-       {
-               printf("Usage: %s <host> <command> [args...]\n", argv[0]);
-               return 1;
+       case 't':
+           timeout = atoi(optarg);
+           break;
+
+       default:
+           USAGE();
+           return EXIT_FAILURE;
        }
 
-       dest_host = strdup(argv[1]);
+    argc -= optind;
+    argv += optind;
 
-       {
-               // Init socket
-               int on = 1;
-               struct sockaddr_in addr;
-
-               memset(&addr, 0, sizeof(addr));
-               addr.sin_family = AF_INET;
-               addr.sin_port = htons(1703);
-               udpfd = socket(AF_INET, SOCK_DGRAM, 17);
-               setsockopt(udpfd, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on));
-               if (bind(udpfd, (void *) &addr, sizeof(addr)) < 0)
-               {
-                       perror("bind");
-                       return(1);
-               }
-       }
+    if (argc < 1 || !argv[0][0])
+    {
+       USAGE();
+       return EXIT_FAILURE;
+    }
 
+    if (!host)
+       host = "127.0.0.1";
+
+    if ((p = strchr(host, ':')))
+    {
+       port = atoi(p + 1);
+       if (!port)
        {
-               struct hostent *h = gethostbyname(dest_host);
-               if (h) dest_ip = ntohl(*(unsigned int *)h->h_addr);
-               if (!dest_ip) dest_ip = ntohl(inet_addr(dest_host));
-               if (!dest_ip)
-               {
-                       printf("Can't resolve \"%s\"\n", dest_host);
-                       return 0;
-               }
+           fprintf(stderr, "%s: invalid port `%s'\n", me, p + 1);
+           return EXIT_FAILURE;
        }
 
-       if (!(packet = calloc(1400, 1)))
+       *p = 0;
+    }
+    else
+    {
+       port = NSCTL_PORT;
+    }
+
+    for (i = 0; !req_type && builtins[i].command; i++)
+       if (!strcmp(argv[0], builtins[i].command))
+           req_type = builtins[i].action;
+
+    if (req_type == NSCTL_REQ_HELP)
+    {
+       printf("Available commands:\n");
+       for (i = 0; builtins[i].command; i++)
+           printf("  %s%s\n", builtins[i].command, builtins[i].usage);
+    }
+
+    if (req_type)
+    {
+       argc--;
+       argv++;
+    }
+    else
+    {
+       req_type = NSCTL_REQ_CONTROL;
+    }
+
+    if ((res = request(host, port, req_type, argc, argv)))
+    {
+       FILE *stream = stderr;
+       int status = EXIT_FAILURE;
+
+       if (res->type == NSCTL_RES_OK)
        {
-               perror("calloc");
-               return(1);
+           stream = stdout;
+           status = EXIT_SUCCESS;
        }
 
-       srand(time(NULL));
+       for (i = 0; i < res->argc; i++)
+           fprintf(stream, "%s\n", res->argv[i]);
 
-       // Deal with command & params
-       for (i = 0; i < (sizeof(commands) / sizeof(commands[0])); i++)
-       {
-               if (strcasecmp(commands[i].command, argv[2]) == 0)
-               {
-                       int p;
-                       pkt_type = commands[i].pkt_type;
-                       len = new_packet(pkt_type, packet);
-                       if (argc < (commands[i].params + 3))
-                       {
-                               printf("Not enough parameters for %s\n", argv[2]);
-                               return 1;
-                       }
-                       for (p = 0; p < commands[i].params; p++)
-                       {
-                               strncpy((packet + len), argv[p + 3], 1400 - len);
-                               len += strlen(argv[p + 3]) + 1;
-                       }
-                       break;
-               }
-       }
-       if (!pkt_type)
-       {
-               printf("Unknown command\n");
-               return 1;
-       }
+       return status;
+    }
 
-       send_packet(udpfd, dest_ip, dest_port, packet, len);
+    return EXIT_FAILURE;
+}
 
+static void sigalrm_handler(int sig) { }
+
+static struct nsctl *request(char *host, int port, int type, int argc, char *argv[])
+{
+    static struct nsctl res;
+    struct sockaddr_in peer;
+    socklen_t len = sizeof(peer);
+    struct hostent *h = gethostbyname(host);
+    int fd;
+    char buf[NSCTL_MAX_PKT_SZ];
+    int sz;
+    char *err;
+
+    if (!h || h->h_addrtype != AF_INET)
+    {
+       fprintf(stderr, "%s: invalid host `%s'\n", me, host);
+       return 0;
+    }
+
+    if ((fd = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) < 0)
+    {
+       fprintf(stderr, "%s: can't create udp socket (%s)\n", me, strerror(errno));
+       return 0;
+    }
+
+    memset(&peer, 0, len);
+    peer.sin_family = AF_INET;
+    peer.sin_port = htons(port);
+    memcpy(&peer.sin_addr.s_addr, h->h_addr, sizeof(peer.sin_addr.s_addr));
+
+    if (connect(fd, (struct sockaddr *) &peer, sizeof(peer)) < 0)
+    {
+       fprintf(stderr, "%s: udp connect failed (%s)\n", me, strerror(errno));
+       return 0;
+    }
+
+    if ((sz = pack_control(buf, sizeof(buf), type, argc, argv)) < 0)
+    {
+       fprintf(stderr, "%s: error packing request\n", me);
+       return 0;
+    }
+
+    if (debug)
+    {
+       struct nsctl req;
+       if (unpack_control(&req, buf, sz) == type)
        {
-               int n;
-               fd_set r;
-               struct timeval timeout;
-
-               FD_ZERO(&r);
-               FD_SET(udpfd, &r);
-               timeout.tv_sec = 1;
-               timeout.tv_usec = 0;
-
-               n = select(udpfd + 1, &r, 0, 0, &timeout);
-               if (n <= 0)
-               {
-                       printf("Timeout waiting for packet\n");
-                       return 0;
-               }
+           fprintf(stderr, "Sending ");
+           dump_control(&req, stderr);
        }
-       if ((len = read_packet(udpfd, packet)))
+    }
+
+    if (send(fd, buf, sz, 0) < 0)
+    {
+       fprintf(stderr, "%s: error sending request (%s)\n", me, strerror(errno));
+       return 0;
+    }
+
+    /* set timer */
+    if (timeout)
+    {
+       struct sigaction alrm;
+       alrm.sa_handler = sigalrm_handler;
+       sigemptyset(&alrm.sa_mask);
+       alrm.sa_flags = 0;
+
+       sigaction(SIGALRM, &alrm, 0);
+       alarm(timeout);
+    }
+
+    sz = recv(fd, buf, sizeof(buf), 0);
+    alarm(0);
+
+    if (sz < 0)
+    {
+       fprintf(stderr, "%s: error receiving response (%s)\n", me,
+           errno == EINTR ? "timed out" : strerror(errno));
+
+       return 0;
+    }
+
+    if ((type = unpack_control(&res, buf, sz)) > 0 && type & NSCTL_RESPONSE)
+    {
+       if (debug)
        {
-               printf("Received ");
-               dump_packet(packet, stdout);
+           fprintf(stderr, "Received ");
+           dump_control(&res, stderr);
        }
 
-       return 0;
-}
+       return &res;
+    }
 
+    err = "unknown error";
+    switch (type)
+    {
+    case NSCTL_ERR_SHORT:  err = "short packet"; break;
+    case NSCTL_ERR_LONG:   err = "extra data";   break;
+    case NSCTL_ERR_MAGIC:  err = "bad magic";    break;
+    case NSCTL_ERR_TYPE:   err = "invalid type"; break;
+    }
+
+    fprintf(stderr, "%s: %s\n", me, err);
+    return 0;
+}