/* $Id: ncat_listen.c 11795 2009-01-21 22:16:20Z david $ */

#include "nsock.h"
#include "ncat.h"
#include "sys_wrap.h"
#include "util.h"
#include "nbase.h"

#include <stdio.h>
#include <stdlib.h>
#ifndef WIN32
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <sys/wait.h>
#else
#include <fcntl.h>
#endif
#include <errno.h>
#include <string.h>
#include <sys/types.h>
#include <signal.h>
#include <assert.h>

#ifdef HAVE_OPENSSL
#include <openssl/ssl.h>
#include <openssl/err.h>
#endif

/*
 * Number of live client connections.  In --exec mode it is decremented from
 * the SIGCHLD handler below, so it uses the object type C allows a signal
 * handler to write: volatile sig_atomic_t.
 */
static volatile sig_atomic_t conn_count;

#ifndef WIN32
/*
 * SIGCHLD handler: reap every exited child without blocking and release the
 * connection slot each one held.
 *
 * Calls waitpid() directly instead of the Waitpid() wrapper.  Only
 * async-signal-safe functions belong in a handler, and the loop always ends
 * with waitpid() returning 0 (children still running) or -1/ECHILD (none
 * left).  The wrapper's handling of that expected failure is not visible
 * here.  errno is saved and restored so that code interrupted by the signal
 * (e.g. an accept() error check) does not see a value clobbered here.
 */
static void sig_chld(int signo)
{
    int saved_errno = errno;

    (void) signo;
    while (waitpid(-1, NULL, WNOHANG) > 0)
        conn_count--;

    errno = saved_errno;
}
#endif

/*
 * Nonzero when accept() failed for a reason that should not stop the
 * listener.  Either nothing was pending any more, the client aborted between
 * select() and accept() (UNPv1 2nd ed, p422), or a signal interrupted us.
 */
static int accept_error_is_transient(int err)
{
    if (err == EAGAIN || err == EWOULDBLOCK || err == EINTR)
        return 1;
#ifdef ECONNABORTED
    if (err == ECONNABORTED)
        return 1;
#endif
    return 0;
}

/*
 * Accept one pending connection on the non-blocking LISTEN_SOCK.
 *
 * The connection is closed immediately if the connection limit is reached or
 * allow_access() rejects the peer.  With --exec (o.cmdexec) a child is forked
 * to run the command on the socket and the parent closes its copy.
 * Otherwise the socket is added to MASTER and FDLIST, and stdin starts being
 * watched because there is now a client to send it to.
 */
static void accept_connection(int listen_sock, fd_set *master, fd_list_t *fdlist)
{
    struct sockaddr_storage remotess;
    /* accept() updates the length in place, so it must start at the full
       buffer size on every call or a short peer address truncates later ones. */
    socklen_t sslen = sizeof(remotess);
    int new_fd, accept_errno;
    pid_t pid;
#ifdef HAVE_OPENSSL
    SSL *tmpssl = NULL;
#else
    void *tmpssl = NULL;
#endif

    zmem(&remotess, sizeof(remotess));
    new_fd = accept(listen_sock, (struct sockaddr *) &remotess, &sslen);
    /* Capture errno before the debug Fprintf() below can overwrite it. */
    accept_errno = errno;

    if (verbose_flag > 1)
        Fprintf(stderr, "DEBUG: accept returned %d, errno %d\n", new_fd, accept_errno);

    if (new_fd < 0) {
        if (accept_error_is_transient(accept_errno))
            return;
        errno = accept_errno;
        die("accept");
        return;
    }

    /* check total connection count and deny list */
    if (conn_count >= o.conn_limit || !allow_access(&remotess)) {
        Close(new_fd);
        if (verbose_flag > 1)
            Fprintf(stderr, "DEBUG: New connection denied: %s\n",
                    (conn_count >= o.conn_limit) ?
                    "Max connections reached" :
                    "ACL denial");
        return;
    }

    conn_count++;

#ifdef HAVE_OPENSSL
    if (o.ssl) {
        tmpssl = new_ssl(new_fd);
        if (SSL_accept(tmpssl) != 1) {
            /* A failed handshake is one bad client: report it and keep
               serving everybody else instead of exiting. */
            Fprintf(stderr, "SSL_accept(): %s\n",
                    ERR_error_string(ERR_get_error(), NULL));
            SSL_free(tmpssl);
            Close(new_fd);
            conn_count--;
            return;
        }
    }
#endif

    /*
     * are we executing a command? if so then don't add this guy
     * to our descriptor list or set.
     */
    if (o.cmdexec) {
        pid = Fork();

        if (pid == 0) {
            if ((netexec(new_fd, o.cmdexec)) == -1)
                die("execv");
        }

        /* Parent only reaches here */
#ifdef HAVE_OPENSSL
        /* The child owns the connection now.  Release the parent's SSL state
           without SSL_shutdown(), which would talk to the peer. */
        if (tmpssl != NULL)
            SSL_free(tmpssl);
#endif
        Close(new_fd);
        if (verbose_flag > 1)
            Fprintf(stderr, "DEBUG: listen tcp: parent closed %d\n", new_fd);
    } else {
        /* add to our lists */
        FD_SET(new_fd, master);
        add_fd(fdlist, new_fd, tmpssl);
        /* There is now someone to broadcast stdin to. */
        FD_SET(STDIN_FILENO, master);
    }

    if (verbose_flag > 1)
        Fprintf(stderr, "DEBUG: listen tcp: accepted new connection\n");
}

/*
 * Read one chunk from standard input and send it to every client in FDLIST,
 * then copy it to the normal and hex logs.  Ends the program on stdin EOF.
 */
static void send_stdin_to_clients(int listen_sock, fd_list_t *fdlist)
{
    char buf[DEFAULT_TCP_BUF_LEN];
    int nbytes, sent, y;

    /* Leave room for the NUL terminator and, when o.crlf (0 or 1) is set,
       for the '\r' that may be inserted below. */
    nbytes = Read(STDIN_FILENO, buf, sizeof(buf) - 1 - o.crlf);
    if (nbytes == 0)
        bye("Bye now");
    if (nbytes < 0) {
        die("read");
        return;
    }

    /* Turn a trailing bare "\n" into "\r\n"; a lone "\n" has no byte before
       it to inspect. */
    if (o.crlf && buf[nbytes - 1] == '\n'
        && (nbytes < 2 || buf[nbytes - 2] != '\r')) {
        memcpy(&buf[nbytes - 1], "\r\n", 2);
        nbytes++;
    }

    buf[nbytes] = 0;

    if (o.linedelay)
        ncat_delay_timer(o.linedelay);

    /* loop through all of our clients and send them data */
    for (y = 0; y < fdlist->nfds; y++) {
        int fd = fdlist->fds[y].fd;

        /* don't write to stdin/out/err or the listening socket */
        if (fd == STDIN_FILENO || fd == STDOUT_FILENO ||
            fd == STDERR_FILENO || fd == listen_sock)
            continue;

        /* Keep each client's result separate from nbytes, so one failed send
           cannot change the length used for the next client or the logs. */
#ifdef HAVE_OPENSSL
        if (o.ssl)
            sent = SSL_write(fdlist->fds[y].ssl, buf, nbytes);
        else
#endif
            sent = send(fd, buf, nbytes, 0);

        if (verbose_flag > 1)
            Fprintf(stderr, "DEBUG: wrote %d bytes to client fd %d\n", sent, fd);
    }

    /* dump to proper places */
    if (o.normlogfd != -1)
        Write(o.normlogfd, buf, nbytes);

    if (o.hexlogfd != -1)
        ncat_hexdump(o.hexlogfd, buf, nbytes);
}

/*
 * Handle a readable client socket FD.  Data it sent is copied to standard
 * output and the logs.  EOF or an error tears the connection down, and stdin
 * stops being watched once no clients remain.
 */
static void read_from_client(int fd, fd_set *master, fd_list_t *fdlist)
{
    char buf[DEFAULT_TCP_BUF_LEN];
    struct fdinfo *fdn = get_fdinfo(fdlist, fd);
    int nbytes;

    assert(fdn);

    for (;;) {
        /* One byte is kept free for the NUL terminator. */
#ifdef HAVE_OPENSSL
        if (o.ssl)
            nbytes = SSL_read(fdn->ssl, buf, sizeof(buf) - 1);
        else
#endif
            nbytes = recv(fd, buf, sizeof(buf) - 1, 0);

        if (verbose_flag > 1)
            Fprintf(stderr, "DEBUG: client sent %d bytes\n", nbytes);

        /* are they closing down ? */
        if (nbytes <= 0) {
#ifdef HAVE_OPENSSL
            if (o.ssl) {
                SSL_shutdown(fdn->ssl);
                SSL_free(fdn->ssl);
            }
#endif
            rm_fd(fdlist, fd);
            FD_CLR(fd, master);
            close(fd);   /* if error in read(), don't quit prog */
            conn_count--;

            /* Only the listening socket and stdin remain: stop selecting on
               stdin so select() does not spin on data nobody will read. */
            if (fdlist->nfds == 2)
                FD_CLR(STDIN_FILENO, master);
            return;
        }

        buf[nbytes] = 0;

        if (o.linedelay)
            ncat_delay_timer(o.linedelay);

        if (o.telnet)
            dotelnet(fd, (unsigned char *) buf, nbytes);

        Write(STDOUT_FILENO, buf, nbytes);

        /* dump to proper places */
        if (o.normlogfd != -1)
            Write(o.normlogfd, buf, nbytes);

        if (o.hexlogfd != -1)
            ncat_hexdump(o.hexlogfd, buf, nbytes);

#ifdef HAVE_OPENSSL
        /* SSL can buffer our input, so another select() won't necessarily
           report it; read whatever is already pending now. */
        if (o.ssl && SSL_pending(fdn->ssl))
            continue;
#endif
        break;
    }
}

/*
 * TCP listen mode.
 *
 * The parent handles all incoming connections:
 *   - with --exec, it forks a child per connection to run the command and
 *     keeps listening;
 *   - otherwise it broadcasts stdin to all clients and copies everything the
 *     clients send to stdout.
 *
 * Only returns if the loop is left; errors inside are handled by the helpers
 * (die()/bye() end the program).
 */
static int ncat_listen_tcp()
{
    int listen_sock, fds_ready, x;
    fd_set master, read_fds;
    fd_list_t fdlist;

    FD_ZERO(&master);
    zmem(&fdlist, sizeof(fdlist));

#ifndef WIN32
    /* Reap on SIGCHLD */
    Signal(SIGCHLD, sig_chld);
#endif

#ifdef HAVE_OPENSSL
    /* The returned context was never used locally, so it is not kept. */
    if (o.ssl)
        setup_ssl_listen();
#endif

    /* create listen socket */
    listen_sock = do_listen(SOCK_STREAM);

    /* Make our listening socket non-blocking because there are timing issues
     * which could cause us to block on accept() even though select() says it's
     * readable.  See UNPv1 2nd ed, p422 for more.
     */
    unblock_socket(listen_sock);

    /* stdin joins the select set only while at least one client is
       connected (see accept_connection / read_from_client). */
    FD_SET(listen_sock, &master);

    /* we need a list of fds to keep current fdmax and send data to clients */
    init_fdlist(&fdlist, sadd(o.conn_limit, 2));
    add_fd(&fdlist, listen_sock, NULL);
    add_fd(&fdlist, STDIN_FILENO, NULL);

    while (1) {
        /* poll descriptors */
        if (verbose_flag > 1)
            Fprintf(stderr, "DEBUG: selecting, fdmax %d\n", fdlist.fdmax);
        read_fds = master;

        fds_ready = fselect(fdlist.fdmax + 1, &read_fds, NULL, NULL, NULL);

        if (verbose_flag > 1)
            Fprintf(stderr, "DEBUG: select returned %d fds ready\n", fds_ready);

        /*
         * FIXME: optimize this loop to look only at the fds in the fd list;
         * a single very large descriptor makes us scan many unused ones.
         * read_fds is not meaningful after a failed select, hence "> 0".
         */
        for (x = 0; x <= fdlist.fdmax && fds_ready > 0; x++) {
            if (!FD_ISSET(x, &read_fds))
                continue;

            if (verbose_flag > 1)
                Fprintf(stderr, "DEBUG: fd %d is ready\n", x);

            if (x == listen_sock)
                accept_connection(listen_sock, &master, &fdlist);
            else if (x == STDIN_FILENO)
                send_stdin_to_clients(listen_sock, &fdlist);
            else
                read_from_client(x, &master, &fdlist);

            fds_ready--;
        }
    }

    return 0;
}

/* This is sufficiently different from the TCP code (wrt SSL, etc) that it
 * resides in its own simpler function.
 *
 * Waits for the first datagram from a peer that passes allow_access(), then
 * connect()s the UDP socket to that peer.  From there it either runs --exec
 * on the socket or relays stdin -> socket and socket -> stdout until one side
 * ends.  Returns 0 when the session ends.
 */
static int ncat_listen_udp()
{
    int sockfd, fdmax, nbytes, fds_ready;
    char buf[DEFAULT_UDP_BUF_LEN] = {0};
    fd_set master, read_fds;
    struct sockaddr_storage remotess;
    socklen_t sslen;

    FD_ZERO(&master);

    /* Initialize remotess struct so recvfrom() doesn't hit the fan.. */
    zmem(&remotess, sizeof(remotess));
    remotess.ss_family = o.af;

    /* create the UDP listen socket */
    sockfd = do_listen(SOCK_DGRAM);

    while (1) {
        /* recvfrom() shrinks sslen to the size of the address it stored, so
           restore the full buffer size before every call. */
        sslen = sizeof(remotess);

        /*
         * We just peek so we can get the client connection details without
         * removing anything from the queue. Sigh.
         */
        nbytes = Recvfrom(sockfd, buf, sizeof(buf), MSG_PEEK,
                          (struct sockaddr *) &remotess, &sslen);

        /* check deny list */
        if (!allow_access(&remotess)) {
            if (verbose_flag > 1)
                Fprintf(stderr, "DEBUG: New connection denied: ACL denial\n");

            /* Dump the current datagram */
            Recv(sockfd, buf, sizeof(buf), 0);

            continue;
        }

        break;
    }

    /*
     * We're using connected udp. This has the down side of only
     * being able to handle one udp client at a time
     */
    Connect(sockfd, (struct sockaddr *) &remotess, sslen);

    /* clean slate for buf */
    zmem(buf, sizeof(buf));

    /* are we executing a command? then do it */
    if (o.cmdexec)
        if ((netexec(sockfd, o.cmdexec)) == -1)
            die("execv");

    FD_SET(sockfd, &master);
    FD_SET(STDIN_FILENO, &master);
    fdmax = sockfd;

    /* stdin -> socket and socket -> stdout */
    while (1) {
        read_fds = master;

        if (verbose_flag > 1)
            Fprintf(stderr, "DEBUG: udp select'ing\n");

        fds_ready = fselect(fdmax + 1, &read_fds, NULL, NULL, NULL);
        /* read_fds is unspecified after a failed select(); try again. */
        if (fds_ready < 0)
            continue;

        if (FD_ISSET(STDIN_FILENO, &read_fds)) {
            nbytes = Read(STDIN_FILENO, buf, sizeof(buf));
            if (nbytes == 0) {
                close(sockfd);
                return 0;
            }
            if (nbytes < 0) {
                die("read");
                return 1;
            }
            send(sockfd, buf, nbytes, 0);
        }

        if (FD_ISSET(sockfd, &read_fds)) {
            nbytes = recv(sockfd, buf, sizeof(buf), 0);
            /* An empty datagram ends the session.  So does a recv() error,
               e.g. ECONNREFUSED reported on a connected UDP socket after
               the peer's port closed, rather than passing -1 to Write(). */
            if (nbytes <= 0) {
                close(sockfd);
                return 0;
            }
            Write(STDOUT_FILENO, buf, nbytes);
        }
    }

    return 0;
}

/*
 * Entry point for listen mode.  Dispatches on the parsed options:
 * --http server mode first, then UDP, and TCP otherwise.  Returns whatever
 * the selected handler returns.
 */
int ncat_listen()
{
    if (o.httpserver)
        return ncat_http_server();

    if (o.udp)
        return ncat_listen_udp();

    return ncat_listen_tcp();
}



