/* sniffit - a packet sniffer */
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <errno.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <ctype.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <netdb.h>
#define __FAVOR_BSD
#include <netinet/tcp.h>
#include <netinet/ip.h>

void print_pkt(char *s, int len);

/* Print the command-line help text; `prog` is the program name (argv[0]). */
static void usage(const char *prog)
{
	printf("usage: %s [-a|-p|-s|-c|-h]\n", prog);
	printf("\t-a              show all data\n");
	printf("\t-p <port>       monitor by port\n");
	printf("\t-s <src addr>   monitor by src addr\n");
	printf("\t-c              monitor outgoing connections\n");
	printf("\t-h              show this help\n");
}

/*
 * Parse a decimal TCP port number.  Returns the port (1..65535), or 0 if
 * `s` is empty, has trailing characters, or is out of range.  Unlike
 * atoi(), this rejects input such as "80abc" or "70000".
 */
static int parse_port(const char *s)
{
	char *end;
	long v;

	errno = 0;
	v = strtol(s, &end, 10);
	if (end == s || *end != '\0' || errno == ERANGE || v < 1 || v > 65535)
		return 0;
	return (int) v;
}

/*
 * sniffit entry point.  Takes exactly one monitoring mode from the command
 * line, opens a raw IPv4/TCP socket (root required), then reads packets
 * until read() returns 0 or fails:
 *   -a          print the payload of every TCP segment
 *   -p <port>   print payloads whose source or destination port matches
 *   -s <addr>   print payloads whose IPv4 source address matches
 *   -c          report SYN|ACK and FIN|ACK segments as connection events
 * Payload bytes are written through print_pkt().
 */
int main(int argc, char *argv[])
{
	int c, opts = 0, port = 0, status = 0;
	in_addr_t src = 0;

	if (argc < 2) {
		printf("invalid usage: %s -h for help\n", argv[0]);
		return 0;
	}

	/* No spaces in the optstring: getopt() would treat ' ' as an option. */
	while ((c = getopt(argc, argv, "ap:s:ch")) != -1) {
		if (opts != 0)	/* only the first mode option is honoured */
			break;

		switch (c) {
		case 'a':	/* monitor all data */
			opts = 'a';
			break;

		case 'p':	/* monitor by port */
			opts = 'p';
			if (!(port = parse_port(optarg))) {
				printf("invalid port number: %s\n", optarg);
				return EXIT_FAILURE;
			}
			printf("Monitoring port %d\n", port);
			break;

		case 's': {	/* monitor by source address */
			struct in_addr addr;

			opts = 's';
			/* inet_addr() signals failure with INADDR_NONE (not 0)
			 * and cannot tell it apart from 255.255.255.255;
			 * inet_pton() reports errors unambiguously. */
			if (inet_pton(AF_INET, optarg, &addr) != 1) {
				printf("invalid ip address: %s\n", optarg);
				return EXIT_FAILURE;
			}
			src = addr.s_addr;
			printf("Monitoring address %s\n", inet_ntoa(addr));
			break;
		}

		case 'c':	/* monitor outgoing connections */
			opts = 'c';
			break;

		case 'h':	/* help .. yeah, gotta have one of these! */
			usage(argv[0]);
			return 0;

		default:	/* unknown option or missing option argument */
			usage(argv[0]);
			return EXIT_FAILURE;
		}
	}

	if (opts == 0) {	/* e.g. only non-option arguments were given */
		usage(argv[0]);
		return EXIT_FAILURE;
	}

	/* Raw sockets need privileges, so test the *effective* uid.  Never
	 * return the uid as the exit status: it is truncated to 8 bits, so
	 * uid 256 would exit with "success". */
	if (geteuid() != 0) {
		printf("%s: sorry, you must be root.\n", argv[0]);
		return EXIT_FAILURE;
	}

	int sd = socket(PF_INET, SOCK_RAW, IPPROTO_TCP);
	if (sd < 0) {
		int err = errno;	/* save before perror() can touch it */
		perror("could not create socket");
		return err;
	}

	char buf[8192];

	for (;;) {
		ssize_t n = read(sd, buf, sizeof buf);
		struct ip iph;
		struct tcphdr tcph;
		size_t len, ip_hlen, tcp_hlen, off, end;
		int plen;

		if (n < 0 && errno == EINTR)
			continue;
		if (n <= 0) {
			if (n < 0) {
				perror("read");
				status = EXIT_FAILURE;
			}
			break;
		}
		len = (size_t) n;

		/* Copy headers out of buf rather than casting pointers into it:
		 * buf has no alignment guarantee, and reading a char array
		 * through struct types violates strict aliasing. */
		if (len < sizeof iph)
			continue;
		memcpy(&iph, buf, sizeof iph);
		ip_hlen = (size_t) iph.ip_hl * 4;	/* ip_hl counts 32-bit words */
		if (ip_hlen < sizeof iph || len < ip_hlen + sizeof tcph)
			continue;

		/* The TCP header follows the IP header *including* IP options. */
		memcpy(&tcph, buf + ip_hlen, sizeof tcph);
		tcp_hlen = (size_t) tcph.th_off * 4;	/* th_off counts 32-bit words */
		off = ip_hlen + tcp_hlen;
		if (tcp_hlen < sizeof tcph || off > len)
			continue;

		/* The payload runs to the IP total length, but never past the
		 * bytes read() actually returned (ip_len is a 16-bit field and
		 * can exceed the size of buf). */
		end = ntohs(iph.ip_len);
		if (end < off || end > len)
			end = len;
		plen = (int) (end - off);

		switch (opts) {
		case 'a':
			print_pkt(buf + off, plen);
			break;

		case 'p':
			if (ntohs(tcph.th_sport) == port || ntohs(tcph.th_dport) == port)
				print_pkt(buf + off, plen);
			break;

		case 's':
			if (iph.ip_src.s_addr == src)
				print_pkt(buf + off, plen);
			break;

		case 'c':
			if (tcph.th_flags == (TH_SYN | TH_ACK)) {
				printf("connection to %s:%d has been made.\n",
				    inet_ntoa(iph.ip_src), ntohs(tcph.th_sport));
			} else if (tcph.th_flags == (TH_FIN | TH_ACK)) {
				printf("connection to %s:%d is being closed.\n",
				    inet_ntoa(iph.ip_src), ntohs(tcph.th_sport));
			}
			break;
		}
	}

	close(sd);
	return status;
}

/*
 * Write the printable bytes of s[0..len) to stdout, then flush stdout.
 * A carriage return (byte 13) is written as '\n'; every other
 * non-printable byte, including '\n' itself, is dropped.  A len of
 * zero or less writes nothing.
 */
void print_pkt(char *s, int len)
{
	for (int j = 0; j < len; j++) {
		/* <ctype.h> functions require a value representable as
		 * unsigned char (or EOF); passing a negative plain char,
		 * as raw packet bytes >= 0x80 may be, is undefined. */
		unsigned char ch = (unsigned char) s[j];

		if (ch == 13)
			fputc('\n', stdout);
		else if (isprint(ch))
			fputc(ch, stdout);
	}
	fflush(stdout);
}