/* $Id: util.c 11801 2009-01-21 23:57:59Z david $ */

#include "sys_wrap.h"
#include "util.h"
#include "ncat.h"

#include <stdio.h>
#ifdef WIN32
#include <iphlpapi.h>
#endif
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>

#if HAVE_SYS_STAT_H
#include <sys/stat.h>
#endif
#if HAVE_FCNTL_H
#include <fcntl.h>
#endif
#if HAVE_UNISTD_H
#include <unistd.h>
#endif

/*
 * Return 1 if host parses as a dotted-quad IPv4 address "a.b.c.d" with
 * 1 <= a <= 255 and 0 <= b, c, d <= 255; return 0 otherwise.
 *
 * All four fields must be converted: sscanf() returns EOF on empty input
 * and a short count on partial input such as "1.2", and in either case the
 * unconverted variables would otherwise be read uninitialized.
 * Characters after the fourth number are not examined.
 */
int
isip(char *host)
{
  int a, b, c, d;

  if (sscanf(host, "%d.%d.%d.%d", &a, &b, &c, &d) != 4)
    return 0;

  /* First octet 0 is rejected deliberately. */
  if (a < 1 || a > 255)
    return 0;

  if (b < 0 || b > 255 || c < 0 || c > 255 || d < 0 || d > 255)
    return 0;

  return 1;
}

/*
 * Overflow-checked addition of two size_t values.  Returns l + r, or
 * terminates the program through bye() if the sum wraps around.
 */
size_t sadd(size_t l, size_t r)
{
    size_t sum = l + r;

    /* Unsigned arithmetic wraps, so a wrapped sum ends up smaller than l. */
    if (sum < l)
        bye("integer overflow %lu + %lu", (u_long)l, (u_long)r);

    return sum;
}

/*
 * Overflow-checked multiplication of two size_t values.  Returns l * r, or
 * terminates the program through bye() if the product wraps around.
 */
size_t smul(size_t l, size_t r)
{
    size_t product = l * r;

    /* A wrapped product cannot be divided back by l to recover r. */
    if (l != 0 && product / l != r)
        bye("integer overflow %lu * %lu", (u_long)l, (u_long)r);

    return product;
}

#ifdef WIN32
/* Initialize Winsock requesting version 2.2; exit via bye() on failure. */
void windows_init()
{
	WSADATA wsa_info;

	if (WSAStartup(MAKEWORD(2, 2), &wsa_info) != 0)
		bye("Failed to start WinSock.");
}
#endif

/*
 * Report err together with the description of the current errno (via
 * perror()) and terminate with EXIT_FAILURE.  Does not return.
 */
void die(char *err)
{
    const char *label = err;

    perror(label);
    exit(EXIT_FAILURE);
}

/*
 * Fatal error: print a printf-style message to stderr, append a newline,
 * and exit with EXIT_FAILURE.  Never returns.
 *
 * The trailing newline is written separately rather than appended to a
 * heap copy of fmt.  This is the fatal-error path (sadd/smul/etc. call
 * it), so it should not allocate.  NOTE(review): if Malloc() reports its
 * own failure through bye(), the previous allocating version could
 * recurse.
 */
void bye(const char *fmt, ...)
{
    va_list ap;

    va_start(ap, fmt);
    vfprintf(stderr, fmt, ap);
    va_end(ap);

    fputc('\n', stderr);
    exit(EXIT_FAILURE);
}

/*
 * Set the first n bytes at mem to zero.  Portable stand-in for the
 * deprecated bzero().
 */
void zmem(void *mem, size_t n)
{
    (void) memset(mem, '\0', n);
}

/* Converts an IP address given in a sockaddr_storage to an IPv4 or
   IPv6 IP address string.  Since a static buffer is returned, this is
   not thread-safe and can only be used once in calls like printf().

   Only AF_INET (and AF_INET6 when built with HAVE_IPV6) are supported.
   Any other family, or a conversion failure, terminates via bye(). */
const char *inet_socktop(struct sockaddr_storage *ss) {
  static char buf[INET6_ADDRSTRLEN];
  const void *addr = NULL;

  /* Pick the address field that matches the family.  A NULL source
     pointer must never reach inet_ntop(). */
  if (ss->ss_family == AF_INET)
    addr = &((struct sockaddr_in *) ss)->sin_addr;
#if HAVE_IPV6
  else if (ss->ss_family == AF_INET6)
    addr = &((struct sockaddr_in6 *) ss)->sin6_addr;
#endif
  else
    bye("Invalid address family passed to inet_socktop()");

  if (inet_ntop(ss->ss_family, addr, buf, sizeof(buf)) == NULL) {
    bye("Failed to convert address to presentation format!  Error: %s", strerror(socket_errno()));
  }
  return buf;
}

/*
 * Return the port stored in ss, converted to host byte order.  The port
 * field is chosen according to ss->ss_family (AF_INET, plus AF_INET6 when
 * HAVE_IPV6 is defined).  Any other family terminates via bye().
 */
unsigned short inet_port(struct sockaddr_storage *ss)
{
	switch (ss->ss_family) {
	case AF_INET:
		return ntohs(((struct sockaddr_in *) ss)->sin_port);
#ifdef HAVE_IPV6
	case AF_INET6:
		return ntohs(((struct sockaddr_in6 *) ss)->sin6_port);
#endif
	default:
		break;
	}

	bye("Invalid address family passed to inet_port()");
	return 0;
}

/*
 * Create a socket of the given type bound to the global srcaddr, with
 * SO_REUSEADDR set.  SOCK_STREAM sockets are also put into listening mode
 * with BACKLOG.  Returns the descriptor, or -1 if type is neither
 * SOCK_STREAM nor SOCK_DGRAM.  Socket/Setsockopt/Bind/Listen are project
 * wrappers; presumably they abort on error (not visible here).
 */
int do_listen(int type)
{
    int fd;
    int reuse = 1;

    if (!(type == SOCK_STREAM || type == SOCK_DGRAM))
        return -1;

    fd = Socket(srcaddr.ss_family, type, 0);
    Setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &reuse, sizeof(reuse));
    Bind(fd, (struct sockaddr *) &srcaddr, (int) srcaddrlen);

    /* Datagram sockets have no listen queue. */
    if (type == SOCK_STREAM)
        Listen(fd, BACKLOG);

    if (verbose_flag > 0) {
        Fprintf(stderr, "Listening on %s:%hu\n",
                inet_socktop(&srcaddr), inet_port(&srcaddr));
    }

    return fd;
}

/*
 * Build an IPv4 options buffer containing one NOP (for alignment) followed
 * by a Loose Source and Record Route option.  The option lists
 * routes[0..numroutes-1] and then dstaddr.
 *
 * dstaddr   - final destination, appended after the intermediate hops.
 * routes    - intermediate hops; must hold at least numroutes entries.
 * numroutes - number of hops, 0..8.  Anything else terminates via bye().
 *             The upper bound keeps the buffer within the 40-byte IPv4
 *             options limit.
 * ptr       - written verbatim into the option's pointer byte (presumably
 *             the RFC 791 LSRR pointer; caller supplies it).
 * len       - out: total buffer length in bytes, including the NOP.
 *
 * Returns a zero-initialized heap buffer from Malloc(); release it with
 * free().
 */
char *buildsrcrte(struct in_addr dstaddr, struct in_addr routes[],
                  int numroutes, int ptr, size_t *len)
{
	int x;
	unsigned char *opts, *p;

	/* Validate before sizing.  A negative count would yield a buffer too
	   small to hold even dstaddr. */
	if (numroutes < 0 || numroutes > 8)
		bye("Bad number of routes passed to buildsrcrte()");

	*len = ((size_t) numroutes + 1) * sizeof(struct in_addr) + 4;

	/* unsigned char avoids implementation-defined results when storing
	   bytes >= 0x80 (e.g. 0x83) through a possibly signed char. */
	opts = (unsigned char *) Malloc(*len);
	zmem(opts, *len);
	p = opts;

	*p++ = 0x01; /* IPOPT_NOP, for alignment */
	*p++ = 0x83; /* IPOPT_LSRR */
	*p++ = (unsigned char) (*len - 1); /* option length excludes the nop */
	*p++ = (unsigned char) ptr;

	for (x = 0; x < numroutes; x++) {
		memcpy(p, &routes[x], sizeof(routes[x]));
		p += sizeof(routes[x]);
	}

	memcpy(p, &dstaddr, sizeof(dstaddr));

	return (char *) opts;
}

/*
 * Decide whether the peer address in ss may connect, based on the
 * o.allow / o.allowfile / o.deny / o.denyfile settings.  Returns 1 to
 * accept and 0 to reject.  Non-IPv4 peers are always accepted, because
 * host access control is IPv4-only.
 */
int allow_access(struct sockaddr_storage *ss)
{
    struct in_addr ipv4;

    if (ss->ss_family != AF_INET)
        return 1;

    ipv4 = ((struct sockaddr_in *) ss)->sin_addr;

    /* Allow lists: reject when ncat_hostaccess() reports 0. */
    if (o.allow && ncat_hostaccess(o.allow, NULL, inet_ntoa(ipv4)) == 0)
        return 0;
    if (o.allowfile && ncat_hostaccess(NULL, o.allowfile, inet_ntoa(ipv4)) == 0)
        return 0;

    /* Deny lists: reject when ncat_hostaccess() reports 1. */
    if (o.deny && ncat_hostaccess(o.deny, NULL, inet_ntoa(ipv4)) == 1)
        return 0;
    if (o.denyfile && ncat_hostaccess(NULL, o.denyfile, inet_ntoa(ipv4)) == 1)
        return 0;

    return 1;
}

/*
 * Ugly code to maintain our list of fds so we can keep a correct fdmax for
 * select().  Really this should be generic list code, not this silly bit of
 * stupidity. -sean
 */

/*
 * Append fd (and, when built with OpenSSL, its ssl handle) to the list,
 * raising fdmax if needed.  Returns 0 on success, or -1 if the list already
 * holds maxfds entries.
 */
int add_fd(fd_list_t *fdl, int fd, void *ssl)
{
    struct fdinfo *slot;

    if (fdl->nfds >= fdl->maxfds)
        return -1;

    slot = &fdl->fds[fdl->nfds];
    slot->fd = fd;
#ifdef HAVE_OPENSSL
    slot->ssl = (SSL *) ssl;
#endif
    fdl->nfds += 1;

    fdl->fdmax = (fd > fdl->fdmax) ? fd : fdl->fdmax;

    if (verbose_flag > 1) {
        Fprintf(stderr, "DEBUG: added fd %d to list, nfds %d, maxfd %d\n",
                fd, fdl->nfds, fdl->fdmax);
    }
    return 0;
}

/*
 * Remove fd from the list.  The last entry is moved into the vacated slot,
 * so list order is not preserved.  If fd was the current fdmax, fdmax is
 * recomputed.  Returns 0 on success, or -1 (with a diagnostic on stderr)
 * if the list is empty or fd is not in it.
 */
int rm_fd(fd_list_t *fdl, int fd)
{
    int count = fdl->nfds;
    int idx;

    if (count == 0) {
        Fprintf(stderr, "Program bug: Trying to remove fd from list with no fds\n");
        return -1;
    }

    /* Linear search for the slot holding fd. */
    idx = 0;
    while (idx < count && fdl->fds[idx].fd != fd)
        idx++;

    if (idx == count) {
        Fprintf(stderr, "Program bug: fd (%d) not on list\n", fd);
        return -1;
    }

    /* Overwrite the slot with the tail entry; a self-copy when count == 1. */
    if (verbose_flag > 1) {
        Fprintf(stderr, "DEBUG: swapping fd[%d] (%d) with fd[%d] (%d)\n", idx,
                fdl->fds[idx].fd, count - 1, fdl->fds[count - 1].fd);
    }
    fdl->fds[idx] = fdl->fds[count - 1];
    fdl->nfds = count - 1;

    if (fd == fdl->fdmax)
        fdl->fdmax = get_maxfd(fdl);

    if (verbose_flag > 1) {
        Fprintf(stderr, "DEBUG: removed fd %d from list, nfds %d, maxfd %d\n",
                fd, fdl->nfds, fdl->fdmax);
    }
    return 0;
}

/* Return the largest descriptor in the list, or -1 if the list is empty. */
int get_maxfd(fd_list_t *fdl)
{
    int highest = -1;
    int i;

    for (i = fdl->nfds - 1; i >= 0; i--) {
        if (fdl->fds[i].fd > highest)
            highest = fdl->fds[i].fd;
    }

    return highest;
}

/* Return a pointer to the list entry for fd, or NULL if fd is not listed. */
struct fdinfo *get_fdinfo(fd_list_t *fdl, int fd)
{
	struct fdinfo *found = NULL;
	int i;

	for (i = 0; i < fdl->nfds; i++) {
		if (fdl->fds[i].fd == fd) {
			found = &fdl->fds[i];
			break;
		}
	}

	return found;
}

/*
 * Prepare an empty list with room for maxfds entries.  The entry array is
 * zero-initialized (Calloc) and fdmax starts at -1.
 */
void init_fdlist(fd_list_t *fdl, int maxfds)
{
    fdl->fds = (struct fdinfo *) Calloc(maxfds, sizeof(*fdl->fds));
    fdl->maxfds = maxfds;
    fdl->fdmax = -1;
    fdl->nfds = 0;

    if (verbose_flag > 1) {
        Fprintf(stderr, "DEBUG: initialized fdlist with %d maxfds\n", maxfds);
    }
}

/*
 * Release the entry array owned by fdl and reset the list to an empty,
 * zero-capacity state.  fds is set to NULL and maxfds to 0 so the list
 * holds no dangling pointer.  A later add_fd() then fails its capacity
 * check (returns -1) instead of writing into freed memory.  Call
 * init_fdlist() to reuse the list.
 */
void free_fdlist(fd_list_t *fdl)
{
    free(fdl->fds);
    fdl->fds = NULL;
    fdl->nfds = 0;
    fdl->maxfds = 0;
    fdl->fdmax = -1;
}

