fix protocol-reject
[l2tpns.git] / l2tpns.c
index 80290dd..54b2aff 100644 (file)
--- a/l2tpns.c
+++ b/l2tpns.c
@@ -4,7 +4,7 @@
 // Copyright (c) 2002 FireBrick (Andrews & Arnold Ltd / Watchfront Ltd) - GPL licenced
 // vim: sw=8 ts=8
 
 // Copyright (c) 2002 FireBrick (Andrews & Arnold Ltd / Watchfront Ltd) - GPL licenced
 // vim: sw=8 ts=8
 
-char const *cvs_id_l2tpns = "$Id: l2tpns.c,v 1.133 2005/09/16 05:04:29 bodea Exp $";
+char const *cvs_id_l2tpns = "$Id: l2tpns.c,v 1.139 2005/09/16 11:46:48 bodea Exp $";
 
 #include <arpa/inet.h>
 #include <assert.h>
 
 #include <arpa/inet.h>
 #include <assert.h>
@@ -988,7 +988,10 @@ void adjust_tcp_mss(sessionidt s, tunnelidt t, uint8_t *buf, int len, uint8_t *t
 {
        int d = (tcp[12] >> 4) * 4;
        uint8_t *mss = 0;
 {
        int d = (tcp[12] >> 4) * 4;
        uint8_t *mss = 0;
+       uint8_t *opts;
        uint8_t *data;
        uint8_t *data;
+       uint16_t orig;
+       uint32_t sum;
 
        if ((tcp[13] & 0x3f) & ~(TCP_FLAG_SYN|TCP_FLAG_ACK)) // only want SYN and SYN,ACK
                return;
 
        if ((tcp[13] & 0x3f) & ~(TCP_FLAG_SYN|TCP_FLAG_ACK)) // only want SYN and SYN,ACK
                return;
@@ -996,35 +999,43 @@ void adjust_tcp_mss(sessionidt s, tunnelidt t, uint8_t *buf, int len, uint8_t *t
        if (tcp + d > buf + len) // short?
                return;
 
        if (tcp + d > buf + len) // short?
                return;
 
+       opts = tcp + 20;
        data = tcp + d;
        data = tcp + d;
-       tcp += 20;
 
 
-       while (tcp < data)
+       while (opts < data)
        {
        {
-               if (*tcp == 2 && tcp[1] == 4) // mss option (2), length 4
+               if (*opts == 2 && opts[1] == 4) // mss option (2), length 4
                {
                {
-                       mss = tcp + 2;
+                       mss = opts + 2;
                        if (mss + 2 > data) return; // short?
                        break;
                }
 
                        if (mss + 2 > data) return; // short?
                        break;
                }
 
-               if (*tcp == 0) return; // end of options
-               if (*tcp == 1 || !tcp[1]) // no op (one byte), or no length (prevent loop)
-                       tcp++;
+               if (*opts == 0) return; // end of options
+               if (*opts == 1 || !opts[1]) // no op (one byte), or no length (prevent loop)
+                       opts++;
                else
                else
-                       tcp += tcp[1]; // skip over option
+                       opts += opts[1]; // skip over option
        }
 
        if (!mss) return; // not found
        }
 
        if (!mss) return; // not found
-       if (ntohl(*(uint16_t *) mss) <= MSS) return; // mss OK
+       orig = ntohs(*(uint16_t *) mss);
 
 
-       LOG(5, s, t, "TCP: %s:%u -> %s:%u SYN%s, adjusted mss from %u to %u\n",
-               fmtaddr(*(in_addr_t *)(buf + 12), 0), *(uint16_t *)tcp,
-               fmtaddr(*(in_addr_t *)(buf + 16), 1), *(uint16_t *)(tcp + 2),
-               (tcp[13] & TCP_FLAG_ACK) ? ",ACK" : "",
-               ntohl(*(uint16_t *) mss), MSS);
+       if (orig <= MSS) return; // mss OK
 
 
-       // FIXME
+       LOG(5, s, t, "TCP: %s:%u -> %s:%u SYN%s: adjusted mss from %u to %u\n",
+               fmtaddr(*(in_addr_t *) (buf + 12), 0), ntohs(*(uint16_t *) tcp),
+               fmtaddr(*(in_addr_t *) (buf + 16), 1), ntohs(*(uint16_t *) (tcp + 2)),
+               (tcp[13] & TCP_FLAG_ACK) ? ",ACK" : "", orig, MSS);
+
+       // set mss
+       *(int16_t *) mss = htons(MSS);
+
+       // adjust checksum (see rfc1141)
+       sum = orig + (~MSS & 0xffff);
+       sum += ntohs(*(uint16_t *) (tcp + 16));
+       sum = (sum & 0xffff) + (sum >> 16);
+       *(uint16_t *) (tcp + 16) = htons(sum);
 }
 
 // process outgoing (to tunnel) IP
 }
 
 // process outgoing (to tunnel) IP
@@ -2597,10 +2608,10 @@ void processudp(uint8_t *buf, int len, struct sockaddr_in *addr)
                        l += 6;
                        if (l > mru) l = mru;
 
                        l += 6;
                        if (l > mru) l = mru;
 
-                       q = makeppp(buf, sizeof(buf), 0, 0, s, t, proto);
+                       q = makeppp(buf, sizeof(buf), 0, 0, s, t, PPPLCP);
                        if (!q) return;
 
                        if (!q) return;
 
-                       *q = CodeRej;
+                       *q = ProtocolRej;
                        *(q + 1) = ++sess_local[s].lcp_ident;
                        *(uint16_t *)(q + 2) = l;
                        *(uint16_t *)(q + 4) = htons(proto);
                        *(q + 1) = ++sess_local[s].lcp_ident;
                        *(uint16_t *)(q + 2) = l;
                        *(uint16_t *)(q + 4) = htons(proto);