simplify AVP unhiding code
[l2tpns.git] / l2tpns.c
index fb02e1c..4c047e1 100644 (file)
--- a/l2tpns.c
+++ b/l2tpns.c
@@ -4,7 +4,7 @@
 // Copyright (c) 2002 FireBrick (Andrews & Arnold Ltd / Watchfront Ltd) - GPL licenced
 // vim: sw=8 ts=8
 
-char const *cvs_id_l2tpns = "$Id: l2tpns.c,v 1.84 2005-02-14 06:58:39 bodea Exp $";
+char const *cvs_id_l2tpns = "$Id: l2tpns.c,v 1.85 2005-03-10 03:08:08 bodea Exp $";
 
 #include <arpa/inet.h>
 #include <assert.h>
@@ -187,7 +187,7 @@ static int remove_plugin(char *plugin_name);
 static void plugins_done(void);
 static void processcontrol(uint8_t *buf, int len, struct sockaddr_in *addr, int alen);
 static tunnelidt new_tunnel(void);
-static int unhide_avp(uint8_t *avp, tunnelidt t, sessionidt s, uint16_t length);
+static void unhide_value(uint8_t *value, size_t len, uint16_t type, uint8_t *vector, size_t vec_len);
 
 // return internal time (10ths since process startup)
 static clockt now(void)
@@ -1761,7 +1761,7 @@ void processudp(uint8_t * buf, int len, struct sockaddr_in *addr)
        {                          // control
                uint16_t message = 0xFFFF;      // message type
                uint8_t fatal = 0;
-               uint8_t mandatory = 0;
+               uint8_t mandatory = 0;          // M bit of the message type AVP
                uint8_t chap = 0;               // if CHAP being used
                uint16_t asession = 0;          // assigned session
                uint32_t amagic = 0;            // magic number
@@ -1777,7 +1777,10 @@ void processudp(uint8_t * buf, int len, struct sockaddr_in *addr)
                        return;
                }
 
-               if ((*buf & 0xCA) != 0xC8)
+               // control messages must have bits 0x80|0x40|0x08
+               // (type, length and sequence) set, and bits 0x02|0x01
+               // (offset and priority) clear
+               if ((*buf & 0xCB) != 0xC8)
                {
                        LOG(1, s, t, "Bad control header %02X\n", *buf);
                        STAT(tunnel_rx_errors);
@@ -1888,23 +1891,42 @@ void processudp(uint8_t * buf, int len, struct sockaddr_in *addr)
                if (l)
                {                     // if not a null message
                        // process AVPs
-                       while (l && !(fatal & 0x80))
+                       while (l && !(fatal & 0x80)) // 0x80 = mandatory AVP
                        {
                                uint16_t n = (ntohs(*(uint16_t *) p) & 0x3FF);
                                uint8_t *b = p;
                                uint8_t flags = *p;
                                uint16_t mtype;
-                               p += n;       // next
-                               if (l < n)
+                               if (n < 6 || n > l) // shorter than the AVP header, or overruns the message
                                {
                                        LOG(1, s, t, "Invalid length in AVP\n");
                                        STAT(tunnel_rx_errors);
-                                       fatal = flags;
                                        return;
                                }
+                               p += n;       // next
                                l -= n;
+                               if (flags & 0x3C) // reserved bits, should be clear
+                               {
+                                       LOG(1, s, t, "Unrecognised AVP flags %02X\n", *b);
+                                       fatal = flags;
+                                       continue; // next
+                               }
+                               b += 2;
+                               if (*(uint16_t *) (b))
+                               {
+                                       LOG(2, s, t, "Unknown AVP vendor %d\n", ntohs(*(uint16_t *) (b)));
+                                       fatal = flags;
+                                       continue; // next
+                               }
+                               b += 2;
+                               mtype = ntohs(*(uint16_t *) (b));
+                               b += 2;
+                               n -= 6;
+
                                if (flags & 0x40)
                                {
+                                       uint16_t orig_len;
+
                                        // handle hidden AVPs
                                        if (!*config->l2tpsecret)
                                        {
@@ -1918,40 +1940,36 @@ void processudp(uint8_t * buf, int len, struct sockaddr_in *addr)
                                                fatal = flags;
                                                continue;
                                        }
+                                       if (n < 8)
+                                       {
+                                               LOG(2, s, t, "Short hidden AVP.\n");
+                                               fatal = flags;
+                                               continue;
+                                       }
+
                                        LOG(4, s, t, "Hidden AVP\n");
+
                                        // Unhide the AVP
-                                       n = unhide_avp(b, t, s, n);
-                                       if (n == 0)
+                                       unhide_value(b, n, mtype, session[s].random_vector, session[s].random_vector_length);
+
+                                       orig_len = ntohs(*(uint16_t *) b);
+                                       if (orig_len > n - 2) // value must fit after the 2 octet original length
                                        {
                                                fatal = flags;
                                                continue;
                                        }
+
+                                       b += 2;
+                                       n = orig_len;
                                }
-                               if (*b & 0x3C)
-                               {
-                                       LOG(1, s, t, "Unrecognised AVP flags %02X\n", *b);
-                                       fatal = flags;
-                                       continue; // next
-                               }
-                               b += 2;
-                               if (*(uint16_t *) (b))
-                               {
-                                       LOG(2, s, t, "Unknown AVP vendor %d\n", ntohs(*(uint16_t *) (b)));
-                                       fatal = flags;
-                                       continue; // next
-                               }
-                               b += 2;
-                               mtype = ntohs(*(uint16_t *) (b));
-                               b += 2;
-                               n -= 6;
 
                                LOG(4, s, t, "   AVP %d (%s) len %d\n", mtype, avp_name(mtype), n);
                                switch (mtype)
                                {
                                case 0:     // message type
                                        message = ntohs(*(uint16_t *) b);
-                                       LOG(4, s, t, "   Message type = %d (%s)\n", *b, l2tp_message_type(message));
-                                       mandatorymessage = flags;
+                                       mandatory = flags & 0x80;
+                                       LOG(4, s, t, "   Message type = %d (%s)\n", message, l2tp_message_type(message));
                                        break;
                                case 1:     // result code
                                        {
@@ -2276,8 +2294,8 @@ void processudp(uint8_t * buf, int len, struct sockaddr_in *addr)
                                        break;
                                default:
                                        STAT(tunnel_rx_errors);
-                                       if (mandatorymessage & 0x80)
-                                               tunnelshutdown(t, "Unknown message");
+                                       if (mandatory)
+                                               tunnelshutdown(t, "Unknown message type");
                                        else
                                                LOG(1, s, t, "Unknown message type %d\n", message);
                                        break;
@@ -4727,75 +4745,51 @@ int cmd_show_hist_open(struct cli_def *cli, char *command, char **argv, int argc
 
 /* Unhide an avp.
  *
- * This unencodes the AVP using the L2TP CHAP secret and the
- * previously stored random vector. It replaces the hidden data with
- * the cleartext data and returns the length of the cleartext data
- * (including the AVP "header" of 6 bytes).
- *
- * Based on code from rp-l2tpd by Roaring Penguin Software Inc.
+ * This unencodes the AVP using the L2TP secret and the previously
+ * stored random vector.  It overwrites the hidden data with the
+ * unhidden AVP subformat (2 octet original length, value, padding).
+ *
+ * Per RFC 2661 section 4.3 the first 16 octet pad is
+ * MD5(type + secret + vector) and each later pad is
+ * MD5(secret + previous 16 octets of *hidden* data), so the
+ * hidden octets must be saved before being decoded in place.
  */
-static int unhide_avp(uint8_t *avp, tunnelidt t, sessionidt s, uint16_t length)
+static void unhide_value(uint8_t *value, size_t len, uint16_t type, uint8_t *vector, size_t vec_len)
 {
        MD5_CTX ctx;
-       uint8_t *cursor;
        uint8_t digest[16];
-       uint8_t working_vector[16];
-       uint16_t hidden_length;
-       uint8_t type[2];
-       size_t done, todo;
-       uint8_t *output;
-
-       // Find the AVP type.
-       type[0] = *(avp + 4);
-       type[1] = *(avp + 5);
-
-       // Line up with the hidden data
-       cursor = output = avp + 6;
+       uint8_t cipher[16];     // previous 16 octets of hidden input
+       uint8_t atype[2];       // AVP type, network byte order
+       size_t d = 0;
+
+       atype[0] = (type >> 8) & 0xff;
+       atype[1] =  type       & 0xff;
 
        // Compute initial pad
        MD5Init(&ctx);
-       MD5Update(&ctx, type, 2);
+       MD5Update(&ctx, atype, sizeof(atype));
        MD5Update(&ctx, config->l2tpsecret, strlen(config->l2tpsecret));
-       MD5Update(&ctx, session[s].random_vector, session[s].random_vector_length);
+       MD5Update(&ctx, vector, vec_len);
        MD5Final(digest, &ctx);
 
-       // Get hidden length
-       hidden_length = ((uint16_t) (digest[0] ^ cursor[0])) * 256 + (uint16_t) (digest[1] ^ cursor[1]);
-
-       // Keep these for later use
-       working_vector[0] = *cursor;
-       working_vector[1] = *(cursor + 1);
-       cursor += 2;
-
-       if (hidden_length > length - 8)
-       {
-               LOG(1, s, t, "Hidden length %d too long in AVP of length %d\n", (int) hidden_length, (int) length);
-               return 0;
-       }
-
-       /* Decrypt remainder */
-       done = 2;
-       todo = hidden_length;
-       while (todo)
+       while (len > 0)
        {
-               working_vector[done] = *cursor;
-               *output = digest[done] ^ *cursor;
-               ++output;
-               ++cursor;
-               --todo;
-               ++done;
-               if (done == 16 && todo)
+               // calculate a new pad from the previous 16 hidden octets
+               if (d >= sizeof(digest))
                {
-                       // Compute new digest
-                       done = 0;
                        MD5Init(&ctx);
                        MD5Update(&ctx, config->l2tpsecret, strlen(config->l2tpsecret));
-                       MD5Update(&ctx, &working_vector, 16);
+                       MD5Update(&ctx, cipher, sizeof(cipher));
                        MD5Final(digest, &ctx);
+
+                       d = 0;
                }
-       }
 
-       return hidden_length + 6;
+               // save the hidden octet before it is decoded in place
+               cipher[d] = *value;
+               *value++ ^= digest[d++];
+               len--;
+       }
 }
 
 static int ip_filter_port(ip_filter_portt *p, uint16_t port)