improved load balancing algorithm.
[l2tpns.git] / test / radius.c
1 /* RADIUS authentication load test */
2
3 #define _SVID_SOURCE
4 #define _POSIX_SOURCE
5
6 #include <stdio.h>
7 #include <stdlib.h>
8 #include <stdarg.h>
9 #include <unistd.h>
10 #include <getopt.h>
11 #include <string.h>
12 #include <errno.h>
13 #include <sys/wait.h>
14 #include <sys/mman.h>
15 #include <netinet/in.h>
16 #include <netdb.h>
17 #include <sys/time.h>
18 #include <time.h>
19 #include <fcntl.h>
20 #include <sys/select.h>
21 #include <signal.h>
22 #include "../md5.h"
23
/* getopt(3) globals: index of the next argv element and current option argument. */
extern int optind;
extern char *optarg;
26
/* Bits for struct user.flags. */
enum {
	F_FAKE = 1 << 0,	/* generated user that the server should not know */
	F_BAD = 1 << 1,		/* password deliberately replaced with a wrong one */
	F_USED = 1 << 2,	/* scratch "already taken" marker */
};

/*
 * One test account, kept in a singly linked list, together with the
 * pre-built RADIUS request packet that is sent for it.
 */
struct user {
	char *user;		/* user name */
	char *pass;		/* clear-text password */
	int flags;		/* F_* bits */
	char *request;		/* encoded Access-Request packet */
	int request_len;	/* length of request in bytes */
	struct user *next;	/* next user in the list, or NULL */
};
38
39 typedef uint32_t u32;
40
41 struct user_list {
42 struct user *entry;
43 int attempts;
44 int response;
45 u32 begin;
46 u32 retry;
47 u32 end;
48 };
49
/*
 * Per-child counters kept in shared memory so the parent can print progress:
 * total users assigned, users sent at least once, replies received, users
 * given up on, and a flag set once the child has finished its setup.
 */
struct stats {
	int total, out, in, err, ready;
};
57
/*
 * RADIUS packet codes used by the test. AccessFail is not a wire code: it is
 * stored as a user's response when no reply arrived within the retry budget.
 */
enum {
	AccessRequest = 1,
	AccessAccept = 2,
	AccessReject = 3,
	AccessFail = 99
};

#define USAGE "Usage: %s [-i input] [-n instances] [-f fake] [-b bad] " \
	"[-l limit] server port secret\n"

/* Upper bound on requests sent for any one user. */
enum { MAX_ATTEMPTS = 5 };
69
/*
 * malloc() that never returns NULL: on failure it reports the requested
 * size on stderr and terminates the process with exit status 1.
 */
void *xmalloc(size_t size)
{
	/* malloc(0) may legitimately return NULL; request one byte instead so a
	   zero-size allocation is not misreported as running out of memory */
	void *p = malloc(size ? size : 1);
	if (!p)
	{
		/* %zu: size is a size_t, %d here would be undefined behaviour */
		fprintf(stderr, "out of memory allocating %zu bytes\n", size);
		exit(1);
	}

	return p;
}
81
/*
 * Duplicate the NUL-terminated string s into memory obtained from xmalloc(),
 * so it never returns NULL (xmalloc exits on allocation failure).
 */
char *xstrdup(const char *s)
{
	/* keep the length as size_t: an int could truncate it */
	size_t size = strlen(s) + 1;	/* include the terminator */
	return memcpy(xmalloc(size), s, size);
}
88
89 void *xmmap(size_t size)
90 {
91 void *p = mmap(NULL, size, PROT_READ|PROT_WRITE, MAP_SHARED|MAP_ANONYMOUS, 0, 0);
92
93 if (p == MAP_FAILED)
94 {
95 fprintf(stderr, "out of memory allocating %d shared bytes\n", size);
96 exit(1);
97 }
98
99 return p;
100 }
101
/*
 * printf-style logging to stdout. When the previous call's format contained
 * a newline (or on the first call), the output is prefixed with a local
 * "YYYY-MM-DD HH:MM:SS " timestamp, so a line assembled from several calls
 * is stamped only once. stdout is flushed after every call.
 */
void logmsg(const char *fmt, ...) __attribute__ ((format (printf, 1, 2)));

void logmsg(const char *fmt, ...)
{
	static int at_line_start = 1;

	if (at_line_start)
	{
		char stamp[sizeof("YYYY-MM-DD HH:MM:SS ")];
		time_t t = time(NULL);
		struct tm *tm = localtime(&t);

		/* localtime() returns NULL when the time cannot be converted and
		   strftime() returns 0 (contents unspecified) when the result does
		   not fit; skip the stamp rather than print garbage */
		if (tm && strftime(stamp, sizeof(stamp), "%Y-%m-%d %T ", tm))
			fputs(stamp, stdout);
	}

	va_list ap;
	va_start(ap, fmt);
	vprintf(fmt, ap);
	va_end(ap);

	fflush(stdout);

	at_line_start = strchr(fmt, '\n') != NULL;
}
124
/*
 * SIGUSR1 handler. It intentionally does nothing: having a handler installed
 * lets SIGUSR1 wake the children from their wait instead of terminating them.
 */
void catch(int sig)
{
	(void) sig;
}
126
/* Worker process entry point, defined after main(); never returns. */
__attribute__ ((noreturn))
void child(struct user_list *users, int count, int rshift,
	struct stats *stats, in_addr_t addr, int port, int limit);

/* Wall-clock time recorded at startup; now() counts from this point. */
time_t basetime;
132
133 int main(int argc, char *argv[])
134 {
135 char *input = 0;
136 int instances = 1;
137 int fake = 0;
138 int bad = 0;
139 int limit = 100000;
140 int o;
141
142 while ((o = getopt(argc, argv, "i:n:f:b:l:")) != -1)
143 {
144 switch (o)
145 {
146 case 'i': /* input file */
147 input = optarg;
148 break;
149
150 case 'n': /* parallel instances */
151 instances = atoi(optarg);
152 if (instances < 1 || instances > 32)
153 {
154 fprintf(stderr, "invalid instances value: `%s' (1-32)\n", optarg);
155 return 2;
156 }
157 break;
158
159 case 'f': /* percentage of additional fake users to add */
160 fake = atoi(optarg);
161 if (fake < 1 || fake > 100)
162 {
163 fprintf(stderr, "invalid fake value: `%s' (1-100)\n", optarg);
164 return 2;
165 }
166 break;
167
168 case 'b': /* percentage of users to use incorrect passwords for */
169 bad = atoi(optarg);
170 if (bad < 1 || bad > 100)
171 {
172 fprintf(stderr, "invalid bad value: `%s' (1-100)\n", optarg);
173 return 2;
174 }
175 break;
176
177 case 'l': /* limit number of messages per 1/10 sec */
178 limit = atoi(optarg);
179 if (limit < 1)
180 {
181 fprintf(stderr, "invalid limit value: `%s'\n", optarg);
182 return 2;
183 }
184 break;
185
186 default:
187 fprintf(stderr, USAGE, argv[0]);
188 return 2;
189 }
190 }
191
192 if (argc - optind != 3)
193 {
194 fprintf(stderr, USAGE, argv[0]);
195 return 2;
196 }
197
198 char *server = argv[optind++];
199 char *port_s = argv[optind++];
200 char *secret = argv[optind];
201
202 int port = atoi(port_s);
203 if (port < 1)
204 {
205 fprintf(stderr, "invalid port: `%s'\n", port_s);
206 return 2;
207 }
208
209 in_addr_t server_addr;
210 {
211 struct hostent *h;
212 if (!(h = gethostbyname(server)) || h->h_addrtype != AF_INET)
213 {
214 fprintf(stderr, "invalid server `%s' (%s)\n", server,
215 h ? "no address" : hstrerror(h_errno));
216
217 return 1;
218 }
219
220 memcpy(&server_addr, h->h_addr, sizeof(server_addr));
221 }
222
223 time(&basetime); /* start clock */
224
225 FILE *in = stdin;
226 if (input && !(in = fopen(input, "r")))
227 {
228 fprintf(stderr, "can't open input file `%s' (%s)\n", input,
229 strerror(errno));
230
231 return 1;
232 }
233
234 logmsg("Loading users from %s: ", input ? input : "stdin");
235
236 struct user *users = 0;
237 struct user *u = 0;
238
239 int count = 0;
240 char buf[1024];
241
242 while (fgets(buf, sizeof(buf), in))
243 {
244 count++;
245
246 /* format: username \t password \n */
247 char *p = strchr(buf, '\t');
248 if (!p)
249 {
250 fprintf(stderr, "invalid input line %d (no TAB)\n", count);
251 return 1;
252 }
253
254 *p++ = 0;
255 if (!u)
256 {
257 users = xmalloc(sizeof(struct user));
258 u = users;
259 }
260 else
261 {
262 u->next = xmalloc(sizeof(struct user));
263 u = u->next;
264 }
265
266 u->user = xstrdup(buf);
267 while (*p == '\t')
268 p++;
269
270 char *q = strchr(p, '\n');
271 if (q)
272 *q = 0;
273
274 if (!*p)
275 {
276 fprintf(stderr, "invalid input line %d (no password)\n", count);
277 return 1;
278 }
279
280 u->pass = xstrdup(p);
281 u->flags = 0;
282 u->next = 0;
283 }
284
285 if (input)
286 fclose(in);
287
288 logmsg("%d\n", count);
289
290 if (!count)
291 return 1;
292
293 char *fake_pw = "__fake__";
294 if (fake)
295 {
296 /* add f fake users to make a total of which fake% are bogus */
297 int f = ((count * fake) / (100.0 - fake) + 0.5);
298 char fake_user[] = "__fake_99999999";
299
300 logmsg("Generating %d%% extra fake users: ", fake);
301 for (int i = 0; i < f; i++, count++)
302 {
303 snprintf(fake_user, sizeof(fake_user), "__fake_%d", i);
304 u->next = xmalloc(sizeof(struct user));
305 u = u->next;
306 u->user = xstrdup(fake_user);
307 u->pass = fake_pw;
308 u->flags = F_FAKE;
309 u->next = 0;
310 }
311
312 logmsg("%d\n", f);
313 }
314
315 if (bad)
316 {
317 int b = (count * bad) / 100.0 + 0.5;
318
319 logmsg("Setting %d%% bad passwords: ", bad);
320
321 u = users;
322 for (int i = 0; i < b; i++, u = u->next)
323 {
324 if (u->pass != fake_pw)
325 free(u->pass);
326
327 u->pass = "__bad__";
328 u->flags |= F_BAD;
329 }
330
331 logmsg("%d\n", b);
332 }
333
334 struct user **unsorted = xmalloc(sizeof(struct user) * count);
335
336 u = users;
337 for (int i = 0; i < count; i++, u = u->next)
338 unsorted[i] = u;
339
340 struct user_list *random = xmmap(sizeof(struct user_list) * count);
341 memset(random, 0, sizeof(struct user_list) * count);
342
343 logmsg("Randomising users: ");
344
345 srand(time(NULL) ^ getpid());
346
347 for (int i = 0; i < count; )
348 {
349 int j = 1.0 * count * rand() / RAND_MAX;
350 if (unsorted[j]->flags & F_USED)
351 continue;
352
353 random[i++].entry = unsorted[j];
354 unsorted[j]->flags |= F_USED;
355 }
356
357 logmsg("done\n");
358 logmsg("Building RADIUS queries: ");
359
360 {
361 char pass[128];
362
363 for (u = users; u; u = u->next)
364 {
365 int pw_len = strlen(u->pass);
366 int len = 4 /* code, identifier, length */
367 + 16 /* authenticator */
368 + 2 + strlen(u->user) /* user */
369 + 2 + ((pw_len / 16) + ((pw_len % 16) ? 1 : 0)) * 16;
370 /* encoded password */
371
372 char *p = xmalloc(len);
373 u->request = p;
374 u->request_len = len;
375
376 *p++ = AccessRequest;
377 *p++ = 0; /* identifier set in child */
378 *(uint16_t *) p = htons(len);
379 p += 2;
380
381 /* authenticator */
382 for (int j = 0; j < 16; j++)
383 *p++ = rand();
384
385 *p = 1; /* user name */
386 p[1] = strlen(u->user) + 2;
387 strcpy(p + 2, u->user);
388 p += p[1];
389
390 strcpy(pass, u->pass);
391 while (pw_len % 16)
392 pass[pw_len++] = 0; /* pad */
393
394 for (int j = 0; j < pw_len; j += 16)
395 {
396 MD5_CTX ctx;
397 MD5_Init(&ctx);
398 MD5_Update(&ctx, secret, strlen(secret));
399 if (j)
400 MD5_Update(&ctx, pass + j - 16, 16);
401 else
402 /* authenticator */
403 MD5_Update(&ctx, u->request + 4, 16);
404
405 uint8_t digest[16];
406 MD5_Final(digest, &ctx);
407
408 for (int k = 0; k < 16; k++)
409 pass[j + k] ^= digest[k];
410 }
411
412 *p = 2; /* password */
413 p[1] = pw_len + 2;
414 memcpy(p + 2, pass, pw_len);
415 p += p[1];
416 }
417 }
418
419 logmsg("done\n");
420
421 signal(SIGUSR1, catch);
422
423 struct stats *stats = xmmap(sizeof(struct stats) * instances);
424 memset(stats, 0, sizeof(struct stats) * instances);
425
426 logmsg("Spawning %d processes: ", instances);
427
428 int per_child = count / instances;
429 int rshift = 0;
430 for (u32 tmp = per_child; tmp & 0xff00; tmp >>= 1)
431 rshift++;
432
433 for (int i = 0, offset = 0; i < instances; i++)
434 {
435 int slack = i ? 0 : count % instances;
436
437 stats[i].total = per_child + slack;
438 if (!fork())
439 child(random + offset, per_child + slack, rshift, stats + i,
440 server_addr, port, limit / instances);
441
442 offset += per_child + slack;
443 }
444
445 logmsg("done\n");
446
447 /* wait for children to setup */
448 int ready = 0;
449 do {
450 ready = 0;
451 for (int i = 0; i < instances; i++)
452 ready += stats[i].ready;
453
454 sleep(1);
455 } while (ready < instances);
456
457 /* go! */
458 kill(0, SIGUSR1);
459
460 logmsg("Processing...\n");
461 logmsg(" total: ");
462
463 for (int i = 0; i < instances; i++)
464 logmsg("[%5d %5s %5s]", stats[i].total, "", "");
465
466 logmsg("\n");
467 logmsg(" out/in/err: ");
468
469 int done = 0;
470 do {
471 for (int i = 0; i < instances; i++)
472 logmsg("[%5d %5d %5d]", stats[i].out, stats[i].in,
473 stats[i].err);
474
475 logmsg("\n");
476
477 if (waitpid(-1, NULL, WNOHANG) > 0)
478 done++;
479
480 if (done < instances)
481 {
482 sleep(1);
483 logmsg(" ");
484 }
485 } while (done < instances);
486
487 int a_hist[MAX_ATTEMPTS + 1];
488 memset(&a_hist, 0, sizeof(a_hist));
489
490 u32 min = 0;
491 u32 max = 0;
492 u32 r_hist[64];
493 memset(&r_hist, 0, sizeof(r_hist));
494 int hsz = sizeof(r_hist) / sizeof(*r_hist);
495
496 for (int i = 0; i < count; i++)
497 {
498 if ((random[i].response != AccessAccept &&
499 random[i].response != AccessReject) ||
500 (random[i].attempts < 1 ||
501 random[i].attempts > MAX_ATTEMPTS))
502 {
503 a_hist[MAX_ATTEMPTS]++;
504 continue;
505 }
506
507 a_hist[random[i].attempts - 1]++;
508
509 u32 interval = random[i].end - random[i].begin;
510
511 if (!i || interval < min)
512 min = interval;
513
514 if (interval > max)
515 max = interval;
516
517 /* histogram in 1/10s intervals */
518 int t = interval / 10 + 0.5;
519 if (t > hsz - 1)
520 t = hsz - 1;
521
522 r_hist[t]++;
523 }
524
525 logmsg("Send attempts:\n");
526 for (int i = 0; i < MAX_ATTEMPTS; i++)
527 logmsg(" %6d: %d\n", i + 1, a_hist[i]);
528
529 logmsg(" failed: %d\n", a_hist[MAX_ATTEMPTS]);
530
531 logmsg("Response time in seconds (min %.2f, max %.2f)\n",
532 min / 100.0, max / 100.0);
533
534 for (int i = 0; i < hsz; i++)
535 {
536 if (i < hsz - 1)
537 logmsg(" %3.1f:", i / 10.0);
538 else
539 logmsg(" more:");
540
541 logmsg(" %6d\n", r_hist[i]);
542 }
543
544 return 0;
545 }
546
547 /* time in sec/100 since program commenced */
548 u32 now(void)
549 {
550 struct timeval t;
551 gettimeofday(&t, 0);
552 return (t.tv_sec - basetime) * 100 + t.tv_usec / 10000 + 1;
553 }
554
555 void child(struct user_list *users, int count, int rshift,
556 struct stats *stats, in_addr_t addr, int port, int limit)
557 {
558 int sockets = 1 << rshift;
559 unsigned rmask = sockets - 1;
560
561 int *sock = xmalloc(sizeof(int) * sockets);
562
563 fd_set r_in;
564 int nfd = 0;
565
566 FD_ZERO(&r_in);
567
568 for (int s = 0; s < sockets; s++)
569 {
570 if ((sock[s] = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP)) < 0)
571 {
572 fprintf(stderr, "can't create a UDP socket (%s)\n",
573 strerror(errno));
574
575 exit(1);
576 }
577
578 int flags = fcntl(sock[s], F_GETFL, 0);
579 fcntl(sock[s], F_SETFL, flags | O_NONBLOCK);
580
581 struct sockaddr_in svr;
582 memset(&svr, 0, sizeof(svr));
583 svr.sin_family = AF_INET;
584 svr.sin_port = htons(port);
585 svr.sin_addr.s_addr = addr;
586
587 connect(sock[s], (struct sockaddr *) &svr, sizeof(svr));
588
589 FD_SET(sock[s], &r_in);
590 if (sock[s] + 1 > nfd)
591 nfd = sock[s] + 1;
592 }
593
594 for (int i = 0; i < count; i++)
595 /* set identifier */
596 *((unsigned char *) users[i].entry->request + 1) = i >> rshift;
597
598 stats->ready = 1;
599 pause();
600
601 u32 out_timer = now();
602 int out_count = 0;
603
604 while ((stats->in + stats->err) < count)
605 {
606 u32 time_now = now();
607
608 while (out_timer + 10 < time_now)
609 {
610 out_timer += 10;
611 if (out_count > 0)
612 out_count -= limit;
613 }
614
615 for (int pass = 1; pass <= 2; pass++)
616 {
617 for (int i = 0; i < count && out_count < limit; i++)
618 {
619 if (users[i].response)
620 continue;
621
622 if (users[i].attempts)
623 {
624 if (users[i].retry > time_now)
625 continue;
626 }
627 else if (pass == 1)
628 {
629 /* retries only on the first pass */
630 continue;
631 }
632
633 struct user *e = users[i].entry;
634 if (write(sock[i & rmask], e->request, e->request_len)
635 != e->request_len)
636 break;
637
638 time_now = now();
639 out_count++;
640
641 if (!users[i].attempts)
642 {
643 users[i].begin = time_now;
644 stats->out++;
645 }
646
647 if (++users[i].attempts > MAX_ATTEMPTS)
648 {
649 users[i].response = AccessFail;
650 stats->err++;
651 continue;
652 }
653
654 users[i].retry = time_now + 200 + 100 * (1 << users[i].attempts);
655 }
656 }
657
658 struct timeval tv = { 0, 100000 };
659
660 fd_set r;
661 memcpy(&r, &r_in, sizeof(r));
662
663 if (select(nfd, &r, NULL, NULL, &tv) < 1)
664 continue;
665
666 char buf[4096];
667
668 for (int s = 0; s < sockets; s++)
669 {
670 if (!FD_ISSET(sock[s], &r))
671 continue;
672
673 int sz;
674
675 while ((sz = read(sock[s], buf, sizeof(buf))) > 0)
676 {
677 if (sz < 2)
678 {
679 fprintf(stderr, "short packet returned\n");
680 continue;
681 }
682
683 if (buf[0] != AccessAccept && buf[0] != AccessReject)
684 {
685 fprintf(stderr, "unrecognised response type %d\n",
686 (int) buf[0]);
687
688 continue;
689 }
690
691 int i = s | (((unsigned char) buf[1]) << rshift);
692 if (i < 0 || i > count)
693 {
694 fprintf(stderr, "bogus identifier returned %d\n", i);
695 continue;
696 }
697
698 if (!users[i].attempts)
699 {
700 fprintf(stderr, "unexpected identifier returned %d\n", i);
701 continue;
702 }
703
704 if (users[i].response)
705 continue;
706
707 int expect = (users[i].entry->flags & (F_FAKE|F_BAD))
708 ? AccessReject : AccessAccept;
709
710 if (buf[0] != expect)
711 fprintf(stderr, "unexpected response %d for user %s "
712 "(expected %d)\n", (int) buf[0], users[i].entry->user,
713 expect);
714
715 users[i].response = buf[0];
716 users[i].end = now();
717 stats->in++;
718 }
719 }
720 }
721
722 exit(0);
723 }