#!/usr/bin/env python
# -*- coding: utf-8 -*-

# Copyright (C) 2005 Insecure.Com LLC.
#
# Author: Adriano Monteiro Marques <py.adriano@gmail.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA

import datetime
import os
import subprocess
import sys
import tempfile
import xml.sax

from zenmapCore.Name import APP_NAME
from zenmapCore.NmapParser import NmapParserSAX
from zenmapCore.UmitConf import PathsConfig
from zenmapCore.UmitLogging import log
import zenmapCore.Paths

# The [paths] section of zenmap.conf, loaded once when this module is imported.
# NdiffCommand reads ndiff_command_path from it to locate the ndiff executable.
paths_config = PathsConfig()

class ScanDiff(object):
    """The result of comparing two scans: when each scan started, plus the
    list of per-host differences."""

    def __init__(self):
        # Start times of scan A and scan B. NdiffContentHandler sets these to
        # datetime objects when the a-start/b-start attributes are present.
        self.a_start, self.b_start = None, None
        # HostDiff objects, one per host element in the Ndiff output.
        self.host_diffs = list()

class HostDiff(object):
    """The differences found for a single host."""

    def __init__(self):
        # (addrtype, addr) tuples identifying the host.
        self.addresses = list()
        # Hostname strings identifying the host.
        self.hostnames = list()
        # DiffHunk objects, one per change reported for the host.
        self.diff = list()

class DiffHunk(object):
    """One change within a HostDiff.

    type is one of the hunk type constants (HOST_STATE_CHANGE, ...). The
    parser attaches further attributes that depend on the type, for example
    a_state and b_state for a host state change."""

    def __init__(self, type):
        # The parameter name shadows the builtin, but it is kept so that
        # keyword callers (DiffHunk(type=...)) keep working.
        self.type = type

class NdiffParseException(Exception):
    """Raised when Ndiff's XML output contains an unexpected element, a
    misplaced element, or a missing or malformed attribute."""

# Unique identifiers for each hunk type, numbered consecutively from 0. These
# are the values stored in DiffHunk.type.
HOST_STATE_CHANGE = 0
HOST_ADDRESS_ADD = 1
HOST_ADDRESS_REMOVE = 2
HOST_HOSTNAME_ADD = 3
HOST_HOSTNAME_REMOVE = 4
PORT_ID_CHANGE = 5
PORT_STATE_CHANGE = 6
    
class NdiffContentHandler(xml.sax.handler.ContentHandler):
    """SAX content handler that fills in a ScanDiff from Ndiff's XML output.

    The expected nesting is nmapdiff > scandiff > host > hunk elements.
    address and hostname elements may appear directly inside host (to
    identify the host) or inside the matching host-address-* or
    host-hostname-* hunk. Unknown or misplaced elements, missing required
    attributes, and malformed attribute values raise NdiffParseException."""

    # A dict mapping hunk element names to hunk type constants.
    HUNK_TYPE_MAP = {
        "host-state-change": HOST_STATE_CHANGE,
        "host-address-add": HOST_ADDRESS_ADD,
        "host-address-remove": HOST_ADDRESS_REMOVE,
        "host-hostname-add": HOST_HOSTNAME_ADD,
        "host-hostname-remove": HOST_HOSTNAME_REMOVE,
        "port-id-change": PORT_ID_CHANGE,
        "port-state-change": PORT_STATE_CHANGE
    }

    # The names of all hunk elements.
    HUNK_ELEMENTS = HUNK_TYPE_MAP.keys()

    def __init__(self, scan_diff):
        """scan_diff is the ScanDiff object that parsing fills in."""
        # Run the base class initializer too (it initializes the document
        # locator attribute).
        xml.sax.handler.ContentHandler.__init__(self)
        self.scan_diff = scan_diff
        # The HostDiff and DiffHunk currently being built, or None when the
        # parser is not inside a host or hunk element.
        self.current_host_diff = None
        self.current_hunk = None

        # Number of scandiff elements already closed; only one is allowed.
        self.num_scandiff = 0
        # Names of the currently open elements, outermost first.
        self.element_stack = []

    def parent(self):
        """Return the parent of the current element, or None if this is the root
        element."""
        if not self.element_stack:
            return None
        return self.element_stack[-1]

    def _required_attr(self, element, attrs, key):
        """Return attribute key of element, raising NdiffParseException if it
        is missing."""
        value = attrs.get(key)
        if value is None:
            raise NdiffParseException("%s element is missing the %s attribute" % (element, key))
        return value

    def _port_spec(self, element, attrs, portid_key, protocol_key):
        """Return the (portid, protocol) tuple read from the given attributes
        of element. portid is converted to an int."""
        portid = self._required_attr(element, attrs, portid_key)
        try:
            portid = int(portid)
        except ValueError:
            raise NdiffParseException("%s attribute of %s element is not an integer" % (portid_key, element))
        return (portid, self._required_attr(element, attrs, protocol_key))

    def _timestamp_attr(self, attrs, key):
        """Return attribute key converted to a datetime, or None if the
        attribute is absent."""
        value = attrs.get(key)
        if value is None:
            return None
        try:
            return datetime.datetime.fromtimestamp(float(value))
        except (ValueError, OverflowError, OSError):
            # float() rejects non-numbers with ValueError; fromtimestamp()
            # rejects out-of-range values with ValueError, OverflowError, or
            # OSError depending on the platform and Python version.
            raise NdiffParseException("Value of %s attribute is not an timestamp: %s" % (key, value))

    def startElement(self, name, attrs):
        """Validate an opening element and record the data it carries."""
        parent = self.parent()
        if name == "nmapdiff":
            if parent is not None:
                raise NdiffParseException("Unexpected %s element" % name)
        elif name == "scandiff":
            if parent != "nmapdiff":
                raise NdiffParseException("Unexpected %s element" % name)
            if self.num_scandiff > 0:
                raise NdiffParseException("Only one scandiff element is allowed")
            a_start = self._timestamp_attr(attrs, "a-start")
            if a_start is not None:
                self.scan_diff.a_start = a_start
            b_start = self._timestamp_attr(attrs, "b-start")
            if b_start is not None:
                self.scan_diff.b_start = b_start
        elif name == "host":
            if parent != "scandiff":
                raise NdiffParseException("Unexpected %s element" % name)
            self.current_host_diff = HostDiff()
        elif name == "address":
            address = (self._required_attr(name, attrs, "addrtype"),
                self._required_attr(name, attrs, "addr"))
            if parent == "host":
                self.current_host_diff.addresses.append(address)
            elif parent == "host-address-add" or parent == "host-address-remove":
                if hasattr(self.current_hunk, "address"):
                    raise NdiffParseException("Only one address element is allowed inside %s" % parent)
                self.current_hunk.address = address
            else:
                raise NdiffParseException("Unexpected %s element" % name)
        elif name == "hostname":
            hostname = self._required_attr(name, attrs, "name")
            if parent == "host":
                self.current_host_diff.hostnames.append(hostname)
            elif parent == "host-hostname-add" or parent == "host-hostname-remove":
                if hasattr(self.current_hunk, "hostname"):
                    raise NdiffParseException("Only one hostname element is allowed inside %s" % parent)
                self.current_hunk.hostname = hostname
            else:
                raise NdiffParseException("Unexpected %s element" % name)
        elif name in self.HUNK_ELEMENTS:
            if parent != "host":
                raise NdiffParseException("Unexpected %s element" % name)
            hunk = DiffHunk(self.HUNK_TYPE_MAP[name])
            self.current_hunk = hunk
            if name == "host-state-change":
                hunk.a_state = self._required_attr(name, attrs, "a-state")
                hunk.b_state = self._required_attr(name, attrs, "b-state")
            elif name in ("host-address-add", "host-address-remove",
                    "host-hostname-add", "host-hostname-remove"):
                # These hunks get their data from an address or hostname
                # subelement; endElement checks that it was present.
                pass
            elif name == "port-id-change":
                hunk.a_spec = self._port_spec(name, attrs, "a-portid", "a-protocol")
                hunk.b_spec = self._port_spec(name, attrs, "b-portid", "b-protocol")
            elif name == "port-state-change":
                hunk.spec = self._port_spec(name, attrs, "portid", "protocol")
                hunk.a_state = self._required_attr(name, attrs, "a-state")
                hunk.b_state = self._required_attr(name, attrs, "b-state")
            else:
                assert False, "Element %s in HUNK_ELEMENTS was not handled." % name
        else:
            raise NdiffParseException("Unknown element: %s" % name)

        self.element_stack.append(name)

    def endElement(self, name):
        """Check that a closing element was complete and attach finished
        host and hunk objects to their containers."""
        self.element_stack.pop()
        if name == "scandiff":
            self.num_scandiff += 1
        elif name in ("host-address-add", "host-address-remove"):
            if not hasattr(self.current_hunk, "address"):
                raise NdiffParseException("No address subelement found inside %s" % name)
        elif name in ("host-hostname-add", "host-hostname-remove"):
            if not hasattr(self.current_hunk, "hostname"):
                raise NdiffParseException("No hostname subelement found inside %s" % name)

        if name == "host":
            self.scan_diff.host_diffs.append(self.current_host_diff)
            self.current_host_diff = None
        if name in self.HUNK_ELEMENTS:
            self.current_host_diff.diff.append(self.current_hunk)
            self.current_hunk = None

def get_path():
    """Build a PATH value suitable for running helper programs such as Ndiff.

    Starts from the PATH of the current environment (an empty list of
    directories if PATH is unset). Then appends each directory from
    zenmapCore.Paths.get_extra_executable_search_paths() that is not already
    listed. The result is joined with os.pathsep."""
    current = os.getenv("PATH")
    dirs = []
    if current is not None:
        dirs = current.split(os.pathsep)
    for extra_dir in zenmapCore.Paths.get_extra_executable_search_paths():
        if extra_dir in dirs:
            continue
        dirs.append(extra_dir)
    return os.pathsep.join(dirs)

class NdiffCommand(subprocess.Popen):
    """A running Ndiff process comparing two scan result files.

    Ndiff's XML output (requested with --xml) goes to an anonymous temporary
    file. Once the process has finished, get_scan_diff parses that file.
    stderr is a pipe, available as self.stderr."""

    def __init__(self, filename_a, filename_b, temporary_filenames = None):
        """Start Ndiff on filename_a and filename_b.

        temporary_filenames is an optional list of files that close() deletes
        (ndiff() passes the scans it wrote to temporary files). The list is
        stored by reference, not copied."""
        # Create a fresh list per instance instead of sharing a mutable
        # default argument between all instances.
        if temporary_filenames is None:
            temporary_filenames = []
        self.temporary_filenames = temporary_filenames

        search_paths = get_path()
        env = dict(os.environ)
        env["PATH"] = search_paths

        command_list = [paths_config.ndiff_command_path, "--xml", filename_a, filename_b]
        self.stdout_file = tempfile.TemporaryFile(mode = "rb", prefix = APP_NAME + "-ndiff-", suffix = ".xml")

        log.debug("Running command: %s" % repr(command_list))
        started = False
        try:
            # See zenmapCore.NmapCommand.py for an explanation of the shell argument.
            subprocess.Popen.__init__(self, command_list, stdout = self.stdout_file, stderr = subprocess.PIPE, env = env, shell = (sys.platform == "win32"))
            started = True
        finally:
            # If the process could not be started (e.g. ndiff not found), the
            # output file would otherwise stay open until garbage collection.
            if not started:
                self.stdout_file.close()

    def get_scan_diff(self):
        """Wait for Ndiff to finish and parse its XML output into a ScanDiff.

        Raises NdiffParseException if the output has an unexpected structure.
        Malformed XML makes xml.sax.parse raise its own exceptions."""
        self.wait()

        scan_diff = ScanDiff()
        self.stdout_file.seek(0)
        xml.sax.parse(self.stdout_file, NdiffContentHandler(scan_diff))

        return scan_diff

    def close(self):
        """Close the output file and delete the temporary input files.

        A failure to delete one file is logged and does not prevent the
        others from being deleted."""
        self.stdout_file.close()
        for filename in self.temporary_filenames:
            log.debug("Remove temporary diff file %s." % filename)
            try:
                os.remove(filename)
            except OSError:
                log.debug("Unable to remove temporary diff file %s." % filename)
        self.temporary_filenames = []

    def kill(self):
        """Stop Ndiff if it is still running, then clean up with close().

        This overrides subprocess.Popen.kill, so it must terminate (and reap)
        the process itself before close() deletes input files that Ndiff may
        still be reading. Popen.kill exists only in Python 2.6 and later; on
        older versions only close() is performed."""
        if self.poll() is None:
            popen_kill = getattr(subprocess.Popen, "kill", None)
            if popen_kill is not None:
                # NOTE(review): on win32 the child is started through the
                # shell, so this may stop only the shell process - verify.
                try:
                    popen_kill(self)
                    self.wait()
                except OSError:
                    # Most likely the process exited between poll() and kill().
                    pass
        self.close()

def _scan_filename(scan, temporary_filenames):
    """Return the name of a file holding scan's XML.

    If scan is an NmapParserSAX, its XML is written to a new temporary file.
    The file's name is appended to temporary_filenames before writing, so the
    caller can still remove it if the write fails. Any other scan is assumed
    to be a filename already and is returned unchanged."""
    if not isinstance(scan, NmapParserSAX):
        return scan
    fd, filename = tempfile.mkstemp(prefix = APP_NAME + "-diff-", suffix = ".xml")
    temporary_filenames.append(filename)
    f = os.fdopen(fd, "wb")
    try:
        scan.write_xml(f)
    finally:
        f.close()
    return filename

def _remove_temporary_files(filenames):
    """Delete each file in filenames, logging rather than raising failures."""
    for filename in filenames:
        try:
            os.remove(filename)
        except OSError:
            log.debug("Unable to remove temporary diff file %s." % filename)

def ndiff(scan_a, scan_b):
    """Run Ndiff on two scan results, which may be filenames or NmapParserSAX
    objects, and return a running NdiffCommand object.

    Temporary files written for NmapParserSAX inputs are passed to the
    NdiffCommand, whose close() deletes them. If writing a scan or starting
    Ndiff fails, they are deleted here before the exception propagates."""
    temporary_filenames = []
    command = None
    try:
        filename_a = _scan_filename(scan_a, temporary_filenames)
        filename_b = _scan_filename(scan_b, temporary_filenames)
        command = NdiffCommand(filename_a, filename_b, temporary_filenames)
    finally:
        # Only clean up on failure; on success the command owns the files.
        if command is None:
            _remove_temporary_files(temporary_filenames)
    return command

def partition_port_state_changes(diff):
    """Group the PORT_STATE_CHANGE hunks of diff by (protocol, a_state, b_state).

    Hunks of other types are skipped. Returns the dict's values: one list of
    hunks per distinct key, with hunks in the order they appear in diff."""
    groups = {}
    for h in diff:
        if h.type == PORT_STATE_CHANGE:
            transition = (h.a_state, h.b_state)
            key = (h.spec[1],) + transition
            if key not in groups:
                groups[key] = []
            groups[key].append(h)
    return groups.values()

def consolidate_port_state_changes(diff, threshold = 0):
    """Pull groups of similar PORT_STATE_CHANGE hunks out of diff.

    Hunks are grouped by protocol and state transition (a_state -> b_state),
    as done by partition_port_state_changes. Only groups with more than
    threshold hunks are consolidated. Returns a list of those groups (each a
    list of hunks). The consolidated hunks are removed from diff in place;
    all other hunks keep their order."""
    consolidated = []
    for group in partition_port_state_changes(diff):
        if len(group) > threshold:
            consolidated.append(group)
    if consolidated:
        # Remove everything in one pass. Calling diff.remove() per hunk is
        # quadratic, which matters for scans with thousands of port changes.
        # Hunks are matched by identity, as list.remove does for objects
        # without a custom __eq__.
        moved = set(id(hunk) for group in consolidated for hunk in group)
        diff[:] = [hunk for hunk in diff if id(hunk) not in moved]
    return consolidated
