#!/usr/bin/env python

# Ndiff
#
# This program reads two Nmap XML files and displays a list of their
# differences.
#
# Copyright 2008 Insecure.Com LLC
# Ndiff is distributed under the same license as Nmap. See the file COPYING or
# http://nmap.org/data/COPYING. See http://nmap.org/book/man-legal.html for more
# details.
#
# David Fifield
# based on a design by Michael Pattrick

import datetime
import getopt
import sys
import time
import xml.sax
import xml.dom.minidom

# Port state transitions with more than this many occurrences are consolidated
# into one text output entry. ScanDiff.print_text passes this value as the
# threshold argument of consolidate_port_state_changes.
PORT_STATE_CHANGE_CONSOLIDATION_THRESHOLD = 10
# Consolidated port state ranges whose *character length* is greater than this
# will be doubly consolidated into just a count of ports in text output. The
# verbose flag of ScanDiff.print_text overrides this and always lists the ports.
PORT_STATE_CHANGE_DOUBLE_CONSOLIDATION_CHAR_THRESHOLD = 80

class Port(object):
    """One port of a host: a specification, a state, and a service version.

    The specification ("spec") is the 2-tuple (number, protocol), so
    (10, "tcp") stands for the port 10/tcp. Port states are strings."""
    # State of a port that has not been scanned. It must not compare equal
    # with any real Nmap port state such as "open" or "closed". Always compare
    # against Port.UNKNOWN rather than the literal string "unknown" so the
    # representation can change later.
    UNKNOWN = "unknown"

    def __init__(self, spec, state = None):
        # Any false state (None, empty string) means the port is unknown.
        if not state:
            state = Port.UNKNOWN
        self.spec = spec
        self.state = state
        self.service = Service()

    def get_state_string(self):
        """Return this port's state as a unicode string."""
        return Port.state_to_string(self.state)

    @staticmethod
    def state_to_string(state):
        """Convert a port state to a unicode string."""
        return unicode(state)

    @staticmethod
    def spec_to_string(spec):
        """Format a (number, protocol) spec like u"80/tcp"."""
        return u"%d/%s" % spec

class PortDict(dict):
    """A dict keyed by port spec that never fails a lookup: asking for a spec
    that is not stored creates a new Port for it, stores it, and returns it."""
    def __getitem__(self, key):
        if dict.__contains__(self, key):
            return dict.__getitem__(self, key)
        # Not present yet: build the Port on demand and remember it.
        created = Port(key)
        self[key] = created
        return created

    def __len__(self):
        # Lookups insert entries as a side effect, so a length is deliberately
        # not offered.
        raise ValueError(u"__len__ is not defined for objects of type PortDict.")

class Service(object):
    """A service version as determined by -sV scan. Also contains the looked-up
    port name if -sV wasn't used.

    Two Services compare equal when their name, product, version, and
    extrainfo are all equal."""
    def __init__(self):
        self.name = None
        self.product = None
        self.version = None
        self.extrainfo = None
        # Not tracked (yet): hostname, ostype, devicetype, tunnel.
        # self.hostname = None
        # self.ostype = None
        # self.devicetype = None
        # self.tunnel = None

    def __eq__(self, other):
        """Return True if other describes the same service and version."""
        return self.name == other.name \
            and self.product == other.product \
            and self.version == other.version \
            and self.extrainfo == other.extrainfo

    def __ne__(self, other):
        """Return the negation of __eq__.

        Python 2 does not derive != from ==; without this method, != would
        fall back to comparing object identity and report two equal Services
        as different."""
        return not self.__eq__(other)

    def to_string(self):
        """Get a string like in the SERVICE column of Nmap output."""
        if self.name is None:
            return u""
        else:
            return self.name

    def version_to_string(self):
        """Get a string like in the VERSION column of Nmap output: the product,
        the version, and the parenthesized extrainfo, skipping any that are
        None, joined by single spaces."""
        parts = []
        if self.product is not None:
            parts.append(self.product)
        if self.version is not None:
            parts.append(self.version)
        if self.extrainfo is not None:
            parts.append(u"(%s)" % self.extrainfo)
        return u" ".join(parts)

class Host(object):
    """A single host, with a state (unknown, up, or down), addresses,
    hostnames, a list of OS names, and a dict mapping port specs to Ports."""
    # This represents an "unknown" host state, the state a host is in when it
    # hasn't been scanned. It must not compare equal with the real Nmap host
    # states "up" and "down". For future compatibility's sake, always compare
    # against Host.UNKNOWN, not the literal string "unknown".
    UNKNOWN = "unknown"

    def __init__(self):
        self.state = self.UNKNOWN
        # Addresses are represented as (address_type, address) tuples.
        self.addresses = []
        self.hostnames = []
        self.ports = PortDict()
        self.os = []

    def get_id(self):
        """Return an id that is used to determine if hosts are "the same" across
        scans: the smallest address tuple if there is one, else the smallest
        hostname, else the id() of this object."""
        if len(self.addresses) > 0:
            return sorted(self.addresses)[0]
        if len(self.hostnames) > 0:
            return sorted(self.hostnames)[0]
        return id(self)

    def format_name(self):
        """Return a human-readable identifier for this host: "hostname
        (address)", or whichever of the two is available, or the id() of this
        object as a last resort."""
        address = None
        # Prefer an IPv4 address, then IPv6, then MAC.
        for address_type in (u"ipv4", u"ipv6", u"mac"):
            addrs = [addr for addr in self.addresses if addr[0] == address_type]
            if len(addrs) > 0:
                address = addrs[0][1]
                break
        hostname = None
        if len(self.hostnames) > 0:
            hostname = self.hostnames[0]
        if hostname is not None:
            if address is not None:
                return u"%s (%s)" % (hostname, address)
            else:
                return hostname
        elif address is not None:
            return address

        return unicode(id(self))

    def add_port(self, port):
        """Store port under its own spec, replacing any Port already there."""
        self.ports[port.spec] = port

    def swap_ports(self, spec_a, spec_b):
        """Swap the ports given by the two specs. This is used when a service is
        moved from one port to another."""
        port_a = self.ports[spec_a]
        port_b = self.ports[spec_b]
        assert port_a.spec == spec_a
        assert port_b.spec == spec_b
        port_a.spec, port_b.spec = port_b.spec, port_a.spec
        self.ports[spec_a], self.ports[spec_b] = \
            self.ports[spec_b], self.ports[spec_a]

    def get_known_ports(self):
        """Return a list of all this host's Ports that are not in the unknown
        state."""
        return [p for p in self.ports.values() if p.state != Port.UNKNOWN]

    def add_address(self, address_type, address):
        """Add the (address_type, address) pair unless it is already present."""
        # The list holds tuples, so membership must be tested with the tuple;
        # testing the bare address string would never match and would let
        # duplicates through.
        entry = (address_type, address)
        if entry not in self.addresses:
            self.addresses.append(entry)

    def remove_address(self, address_type, address):
        """Remove the (address_type, address) pair; do nothing if absent."""
        try:
            self.addresses.remove((address_type, address))
        except ValueError:
            pass

    def add_hostname(self, hostname):
        """Add hostname unless it is already present."""
        if hostname not in self.hostnames:
            self.hostnames.append(hostname)

    def remove_hostname(self, hostname):
        """Remove hostname; do nothing if absent."""
        try:
            self.hostnames.remove(hostname)
        except ValueError:
            pass

class Scan(object):
    """The result of one Nmap invocation: a start date, an end date, and the
    list of hosts that were scanned. A Scan can fill itself in from Nmap XML
    output."""
    def __init__(self):
        self.hosts = []
        self.start_date = None
        self.end_date = None

    def load(self, f):
        """Parse the Nmap XML read from the file-like object f into this
        scan."""
        sax_parser = xml.sax.make_parser()
        sax_parser.setContentHandler(NmapContentHandler(self))
        sax_parser.parse(f)

    def load_from_file(self, filename):
        """Open the Nmap XML file called filename and load this scan from it.
        The file is closed again whether or not loading succeeds."""
        xml_file = open(filename, "r")
        try:
            self.load(xml_file)
        finally:
            xml_file.close()


# What follows are all the possible diff hunk types. Each is capable of rendering
# itself to a string or to an XML DOM fragment.
class DiffHunk(object):
    """A DiffHunk is a single unit of a diff. Each one represents one atomic
    change: a host state change, port state change, and so on.

    This is the abstract base class; subclasses must override both methods."""

    def to_string(self):
        """Return a human-readable string representation of this hunk."""
        # NotImplementedError is a subclass of Exception, so existing callers
        # that catch Exception keep working.
        raise NotImplementedError(u"The to_string method must be overridden.")

    def to_dom_fragment(self, document):
        """Return an XML DocumentFragment representation of this hunk, created
        with the given DOM document."""
        raise NotImplementedError(u"The to_dom_fragment method must be overridden.")

class HostStateChangeHunk(DiffHunk):
    """The state of a host changed from a_state to b_state."""
    def __init__(self, a_state, b_state):
        self.a_state = a_state
        self.b_state = b_state

    def to_string(self):
        # The new state is reported first, then the old one.
        states = (self.b_state, self.a_state)
        return u"Host is %s, was %s." % states

    def to_dom_fragment(self, document):
        """Return a fragment holding one host-state-change element with
        a-state and b-state attributes."""
        change_elem = document.createElement(u"host-state-change")
        for attr_name, value in ((u"a-state", self.a_state),
                                 (u"b-state", self.b_state)):
            change_elem.setAttribute(attr_name, value)
        fragment = document.createDocumentFragment()
        fragment.appendChild(change_elem)
        return fragment

class HostAddressAddHunk(DiffHunk):
    """An address of the given type was added to a host."""
    def __init__(self, address_type, address):
        self.address_type = address_type
        self.address = address

    def to_string(self):
        details = (self.address_type, self.address)
        return u"Add %s address %s." % details

    def to_dom_fragment(self, document):
        """Return a fragment holding a host-address-add element that wraps an
        address element with addrtype and addr attributes."""
        address_elem = document.createElement(u"address")
        address_elem.setAttribute(u"addrtype", self.address_type)
        address_elem.setAttribute(u"addr", self.address)
        wrapper = document.createElement(u"host-address-add")
        wrapper.appendChild(address_elem)
        fragment = document.createDocumentFragment()
        fragment.appendChild(wrapper)
        return fragment

class HostAddressRemoveHunk(DiffHunk):
    """An address of the given type was removed from a host."""
    def __init__(self, address_type, address):
        self.address_type = address_type
        self.address = address

    def to_string(self):
        details = (self.address_type, self.address)
        return u"Remove %s address %s." % details

    def to_dom_fragment(self, document):
        """Return a fragment holding a host-address-remove element that wraps
        an address element with addrtype and addr attributes."""
        address_elem = document.createElement(u"address")
        address_elem.setAttribute(u"addrtype", self.address_type)
        address_elem.setAttribute(u"addr", self.address)
        wrapper = document.createElement(u"host-address-remove")
        wrapper.appendChild(address_elem)
        fragment = document.createDocumentFragment()
        fragment.appendChild(wrapper)
        return fragment

class HostHostnameAddHunk(DiffHunk):
    """A hostname was added to a host."""
    def __init__(self, hostname):
        self.hostname = hostname

    def to_string(self):
        return u"Add hostname %s." % (self.hostname,)

    def to_dom_fragment(self, document):
        """Return a fragment holding a host-hostname-add element that wraps a
        hostname element with a name attribute."""
        name_elem = document.createElement(u"hostname")
        name_elem.setAttribute(u"name", self.hostname)
        wrapper = document.createElement(u"host-hostname-add")
        wrapper.appendChild(name_elem)
        fragment = document.createDocumentFragment()
        fragment.appendChild(wrapper)
        return fragment

class HostHostnameRemoveHunk(DiffHunk):
    """A hostname was removed from a host."""
    def __init__(self, hostname):
        self.hostname = hostname

    def to_string(self):
        return u"Remove hostname %s." % (self.hostname,)

    def to_dom_fragment(self, document):
        """Return a fragment holding a host-hostname-remove element that wraps
        a hostname element with a name attribute."""
        name_elem = document.createElement("hostname")
        name_elem.setAttribute(u"name", self.hostname)
        wrapper = document.createElement(u"host-hostname-remove")
        wrapper.appendChild(name_elem)
        fragment = document.createDocumentFragment()
        fragment.appendChild(wrapper)
        return fragment

class HostOsAddHunk(DiffHunk):
    """An OS name was added to a host."""
    def __init__(self, os):
        self.os = os

    def to_string(self):
        return u"Add OS \"%s\"." % (self.os,)

    def to_dom_fragment(self, document):
        """Return a fragment holding one host-os-add element whose name
        attribute is the OS name."""
        os_elem = document.createElement(u"host-os-add")
        os_elem.setAttribute(u"name", self.os)
        fragment = document.createDocumentFragment()
        fragment.appendChild(os_elem)
        return fragment

class HostOsRemoveHunk(DiffHunk):
    """An OS name was removed from a host."""
    def __init__(self, os):
        self.os = os

    def to_string(self):
        return u"Remove OS \"%s\"." % (self.os,)

    def to_dom_fragment(self, document):
        """Return a fragment holding one host-os-remove element whose name
        attribute is the OS name."""
        os_elem = document.createElement(u"host-os-remove")
        os_elem.setAttribute(u"name", self.os)
        fragment = document.createDocumentFragment()
        fragment.appendChild(os_elem)
        return fragment

class PortIdChangeHunk(DiffHunk):
    """A service moved from the port a_spec to the port b_spec. Both specs are
    (number, protocol) tuples."""
    def __init__(self, a_spec, b_spec):
        self.a_spec = a_spec
        self.b_spec = b_spec

    def to_string(self):
        old_port = Port.spec_to_string(self.a_spec)
        new_port = Port.spec_to_string(self.b_spec)
        return u"Service on %s is now running on %s." % (old_port, new_port)

    def to_dom_fragment(self, document):
        """Return a fragment holding one port-id-change element with the
        port number and protocol of both specs as attributes."""
        a_number, a_protocol = self.a_spec[0], self.a_spec[1]
        b_number, b_protocol = self.b_spec[0], self.b_spec[1]
        change_elem = document.createElement(u"port-id-change")
        change_elem.setAttribute(u"a-portid", unicode(a_number))
        change_elem.setAttribute(u"a-protocol", a_protocol)
        change_elem.setAttribute(u"b-portid", unicode(b_number))
        change_elem.setAttribute(u"b-protocol", b_protocol)
        fragment = document.createDocumentFragment()
        fragment.appendChild(change_elem)
        return fragment

class PortStateChangeHunk(DiffHunk):
    """The state or the service of the port given by spec differs between
    a_port (before) and b_port (after)."""
    def __init__(self, spec, a_port, b_port):
        self.spec = spec
        self.a_port = a_port
        self.b_port = b_port

    def to_string(self):
        """Return up to two lines: a "-" line for the old port and a "+" line
        for the new one. A port in the unknown state gets no line."""
        a_service = self.a_port.service
        b_service = self.b_port.service
        a_desc = u"%s %s" % (a_service.to_string(), a_service.version_to_string())
        b_desc = u"%s %s" % (b_service.to_string(), b_service.version_to_string())
        rendered = []
        for template, port, desc in (("-%s %s %s", self.a_port, a_desc),
                                     ("+%s %s %s", self.b_port, b_desc)):
            if port.state == Port.UNKNOWN:
                continue
            rendered.append(template % (Port.spec_to_string(port.spec), port.state, desc))
        return u"\n".join(rendered)

    @staticmethod
    def service_elem(service, document, name):
        """Create an element called name carrying every attribute of service
        that is not None."""
        elem = document.createElement(name)
        for attr_name in (u"name", u"product", u"version", u"extrainfo"):
            value = getattr(service, attr_name)
            if value is not None:
                elem.setAttribute(attr_name, value)
        return elem

    def to_dom_fragment(self, document):
        """Return a fragment holding one port-state-change element. The
        a-service and b-service children are included only when the two
        services differ."""
        number, protocol = self.spec[0], self.spec[1]
        change_elem = document.createElement(u"port-state-change")
        change_elem.setAttribute(u"portid", unicode(number))
        change_elem.setAttribute(u"protocol", protocol)
        change_elem.setAttribute(u"a-state", self.a_port.state)
        change_elem.setAttribute(u"b-state", self.b_port.state)
        fragment = document.createDocumentFragment()
        fragment.appendChild(change_elem)
        services_equal = self.a_port.service == self.b_port.service
        if not services_equal:
            for service, tag in ((self.a_port.service, u"a-service"),
                                 (self.b_port.service, u"b-service")):
                change_elem.appendChild(self.service_elem(service, document, tag))
        return fragment

def partition_port_state_changes(diff):
    """Group the PortStateChangeHunks found in diff by the tuple
    (protocol, a_state, b_state). Hunks of any other type are ignored. The
    groups are returned as a list of lists of hunks."""
    groups = {}
    for hunk in diff:
        if isinstance(hunk, PortStateChangeHunk):
            old_state = hunk.a_port.state
            new_state = hunk.b_port.state
            key = (hunk.spec[1], old_state, new_state)
            if key not in groups:
                groups[key] = []
            groups[key].append(hunk)
    return groups.values()

def consolidate_port_state_changes(diff, threshold = 0):
    """Pull large groups of similar PortStateChangeHunks out of diff.

    The hunks are grouped by partition_port_state_changes. Every group with
    more than threshold members is removed from diff in place and returned; the
    return value is the list of those groups. This keeps a flood of "is
    filtered, was unknown" messages from overwhelming the output."""
    big_groups = [group for group in partition_port_state_changes(diff)
                  if len(group) > threshold]
    for group in big_groups:
        for hunk in group:
            diff.remove(hunk)
    return big_groups

class ScanDiff(object):
    """A complete diff of two scans. It holds the two scans and the diff
    between them, which is the list of (host, host_diff) tuples returned by
    scan_diff."""
    def __init__(self, scan_a, scan_b):
        """Create a ScanDiff from the "before" scan_a and the "after" scan_b."""
        self.scan_a = scan_a
        self.scan_b = scan_b
        self.diff = scan_diff(scan_a, scan_b)

    def print_text(self, f = sys.stdout, verbose = False):
        """Print this diff in a human-readable text form to f.

        Large groups of similar port state changes are printed as one entry
        listing the ports. If that port list is long it is replaced by a plain
        count, unless verbose is true."""
        date_strings = []
        for scan in (self.scan_a, self.scan_b):
            if scan.start_date is None:
                date_strings.append(u"<unknown date>")
            else:
                date_strings.append(scan.start_date.ctime())
        print >> f, "%s -> %s" % tuple(date_strings)

        for host, h_diff in self.diff:
            print >> f, "%s:" % host.format_name()
            # Work on a copy; consolidation removes hunks from its argument.
            remaining = h_diff[:]
            groups = consolidate_port_state_changes(remaining, PORT_STATE_CHANGE_CONSOLIDATION_THRESHOLD)
            for hunk in remaining:
                indented = [u"\t" + line for line in hunk.to_string().split(u"\n")]
                print >> f, u"\n".join(indented)
            for group in groups:
                first = group[0]
                a_state = first.a_port.state
                b_state = first.b_port.state
                protocol = first.spec[1]
                numbers = [hunk.spec[0] for hunk in group]
                numbers_string = render_port_list(numbers)
                count = len(numbers)
                was_unknown = a_state == Port.UNKNOWN
                list_ports = verbose or \
                    len(numbers_string) <= PORT_STATE_CHANGE_DOUBLE_CONSOLIDATION_CHAR_THRESHOLD
                if not list_ports:
                    if was_unknown:
                        print >> f, u"\t%d %s ports are %s." % (count, protocol, b_state)
                    else:
                        print >> f, u"\t%d %s ports changed state from %s to %s." % (count, protocol, a_state, b_state)
                    continue
                if was_unknown:
                    print >> f, u"\tThe following %d %s ports are %s:" % (count, protocol, b_state)
                else:
                    print >> f, u"\tThe following %d %s ports changed state from %s to %s:" % (count, protocol, a_state, b_state)
                print >> f, u"\t\t" + numbers_string

    def print_xml(self, f = sys.stdout):
        """Write this diff to f as an XML document with an nmapdiff root
        element. Each host with differences gets a host element containing its
        addresses, its hostnames, and the DOM fragment of every hunk."""
        dom_impl = xml.dom.minidom.getDOMImplementation()
        document = dom_impl.createDocument(None, u"nmapdiff", None)
        scandiff_elem = document.createElement(u"scandiff")
        # Start dates are written as integer timestamps; an unknown date
        # leaves its attribute out.
        for attr_name, scan in ((u"a-start", self.scan_a), (u"b-start", self.scan_b)):
            if scan.start_date is None:
                continue
            timestamp = int(time.mktime(scan.start_date.timetuple()))
            scandiff_elem.setAttribute(attr_name, unicode(timestamp))
        document.documentElement.appendChild(scandiff_elem)

        for host, h_diff in self.diff:
            host_elem = document.createElement(u"host")
            scandiff_elem.appendChild(host_elem)
            for address_type, address in host.addresses:
                address_elem = document.createElement(u"address")
                address_elem.setAttribute(u"addrtype", address_type)
                address_elem.setAttribute(u"addr", address)
                host_elem.appendChild(address_elem)
            for hostname in host.hostnames:
                hostname_elem = document.createElement(u"hostname")
                hostname_elem.setAttribute(u"name", hostname)
                host_elem.appendChild(hostname_elem)
            for hunk in h_diff:
                host_elem.appendChild(hunk.to_dom_fragment(document))

        document.writexml(f, addindent = u"  ", newl = u"\n", encoding = "UTF-8")
        document.unlink()

def port_diff(a, b):
    """Compare the Ports a and b and return the differences as a list of
    DiffHunks: a PortIdChangeHunk if the specs differ, followed by a
    PortStateChangeHunk if the state or the service differs."""
    hunks = []
    if a.spec != b.spec:
        hunks.append(PortIdChangeHunk(a.spec, b.spec))
    unchanged = a.state == b.state and a.service == b.service
    if not unchanged:
        hunks.append(PortStateChangeHunk(b.spec, a, b))
    return hunks

def addresses_diff(a, b):
    """Compare two lists of (address_type, address) tuples. Return a list of
    DiffHunks: a remove hunk for each address only in a, then an add hunk for
    each address only in b."""
    before = set(a)
    after = set(b)
    removed = [HostAddressRemoveHunk(addrtype, addr)
               for addrtype, addr in before - after]
    added = [HostAddressAddHunk(addrtype, addr)
             for addrtype, addr in after - before]
    return removed + added

def hostnames_diff(a, b):
    """Compare two lists of hostnames. Return a list of DiffHunks: a remove
    hunk for each hostname only in a, then an add hunk for each hostname only
    in b."""
    before = set(a)
    after = set(b)
    removed = [HostHostnameRemoveHunk(name) for name in before - after]
    added = [HostHostnameAddHunk(name) for name in after - before]
    return removed + added

def host_diff(a, b):
    """Compare the Hosts a and b and return the differences as a list of
    DiffHunks, in this order: host state, addresses, hostnames, ports (sorted
    by spec), then OS names removed and OS names added."""
    hunks = []
    if a.state != b.state:
        hunks.append(HostStateChangeHunk(a.state, b.state))

    hunks += addresses_diff(a.addresses, b.addresses)
    hunks += hostnames_diff(a.hostnames, b.hostnames)

    # Every spec known to either host is compared on both sides.
    specs = sorted(set(a.ports.keys()).union(set(b.ports.keys())))
    for spec in specs:
        hunks += port_diff(a.ports[spec], b.ports[spec])

    hunks += [HostOsRemoveHunk(name) for name in set(a.os) - set(b.os)]
    hunks += [HostOsAddHunk(name) for name in set(b.os) - set(a.os)]

    return hunks

def scan_diff(a, b):
    """Compare the scans a and b. Hosts are matched across the scans by their
    get_id() value; a host missing from one scan is compared against a fresh
    Host. The return value is a list of (host, host_diff) tuples, one for each
    host with a non-empty diff, where host is the host from a if there is one
    and otherwise the host from b, and host_diff is the list returned by
    host_diff."""
    a_by_id = dict((host.get_id(), host) for host in a.hosts)
    b_by_id = dict((host.get_id(), host) for host in b.hosts)
    result = []
    for host_id in set(a_by_id.keys()).union(set(b_by_id.keys())):
        host_a = a_by_id.get(host_id)
        host_b = b_by_id.get(host_id)
        changes = host_diff(host_a or Host(), host_b or Host())
        if len(changes) > 0:
            result.append((host_a or host_b, changes))
    return result

def warn(str):
    """Print a warning to stderr.

    str is the message to print. (The parameter name shadows the builtin
    inside this function; it is kept so that existing callers keep working.)"""
    print >> sys.stderr, str

def parse_port_list(port_list):
    """Parse a port list like
    "1-1027,1029-1033,1040,1043,1050,1058-1059,1067-1068,1076,1080,1083-1084"
    into a sorted list of distinct numbers. An empty string gives an empty
    list. Raises ValueError if the port list is somehow invalid, including a
    range whose start is not less than its end. This function and
    render_port_list are rough inverses."""
    if port_list == u"":
        return []
    numbers = set()
    for chunk in port_list.split(u","):
        if u"-" not in chunk:
            first = int(chunk)
            last = first
        else:
            first_text, last_text = chunk.split(u"-")
            first = int(first_text)
            last = int(last_text)
            if first >= last:
                raise ValueError(u"In range %s, start must be less than end." % chunk)
        numbers.update(range(first, last + 1))
    return sorted(numbers)

def render_port_list(ports):
    """Render a list of numbers into a string like
    "1-1027,1029-1033,1040,1043,1050,1058-1059,1067-1068,1076,1080,1083-1084".
    Runs of two or more consecutive numbers are written as "start-end"; any
    other number is written on its own. The argument may be any iterable of
    integers and is not modified. This function and parse_port_list are rough
    inverses."""
    # sorted() copies, so the caller's sequence is left untouched.
    ports = sorted(ports)
    chunks = []
    i = 0
    while i < len(ports):
        start = i
        i += 1
        # Extend the run while each number is exactly one more than the last,
        # i.e. while the index distance equals the value distance.
        while i < len(ports) and i - start == ports[i] - ports[start]:
            i += 1
        if i - start == 1:
            chunks.append(u"%d" % ports[start])
        else:
            chunks.append(u"%d-%d" % (ports[start], ports[i - 1]))
    return u",".join(chunks)

class NmapContentHandler(xml.sax.handler.ContentHandler):
    """The xml.sax ContentHandler for the XML parser. It contains a Scan object
    that is filled in and can be read back again once the parse method is
    finished."""
    def __init__(self, scan):
        xml.sax.handler.ContentHandler.__init__(self)
        self.scan = scan

        # We keep a stack of the elements we've seen, pushing on start and
        # popping on end.
        self.element_stack = []

        # Maps a protocol name to the list of port numbers scanned for it, as
        # announced by the scaninfo elements.
        self.scanned_ports = {}
        self.current_host = None
        # The extraports states seen so far in the current host.
        self.current_extraports = []
        self.current_port = None

    def parent_element(self):
        """Return the name of the element containing the current one, or None if
        this is the root element."""
        if len(self.element_stack) == 0:
            return None
        return self.element_stack[-1]

    def startElement(self, name, attrs):
        """This method keeps track of element_stack. The real parsing work is
        done in startElementAux. This is to make it easy for startElementAux to
        bail out on error."""
        self.startElementAux(name, attrs)

        self.element_stack.append(name)

    def endElement(self, name):
        """This method keeps track of element_stack. The real parsing work is
        done in endElementAux."""
        self.element_stack.pop()

        self.endElementAux(name)

    def startElementAux(self, name, attrs):
        """Handle the start of the element called name with attributes attrs.
        Elements missing a required attribute produce a warning and are
        skipped; elements with unrecognized names are ignored."""
        if name == u"nmaprun":
            assert self.parent_element() is None
            start = attrs.get(u"start")
            if start is not None:
                self.scan.start_date = datetime.datetime.fromtimestamp(int(start))
        elif name == u"scaninfo":
            assert self.parent_element() == u"nmaprun"
            try:
                protocol = attrs[u"protocol"]
            except KeyError:
                warn(u"%s element missing the \"protocol\" attribute; skipping." % name)
                return
            try:
                services_str = attrs[u"services"]
            except KeyError:
                warn(u"%s element missing the \"services\" attribute; skipping." % name)
                return
            try:
                services = parse_port_list(services_str)
            except ValueError:
                # Include the offending list; the message used to be printed
                # with a literal "%s" in it.
                warn(u"Services list \"%s\" cannot be parsed; skipping." % services_str)
                return
            assert protocol not in self.scanned_ports
            self.scanned_ports[protocol] = services
        elif name == u"host":
            assert self.parent_element() == u"nmaprun"
            self.current_host = Host()
            self.scan.hosts.append(self.current_host)
        elif name == u"status":
            assert self.parent_element() == u"host"
            assert self.current_host is not None
            try:
                state = attrs[u"state"]
            except KeyError:
                warn(u"%s element of host %s is missing the \"state\" attribute; assuming \"unknown\"." % (name, self.current_host.format_name()))
                return
            self.current_host.state = state
        elif name == u"address":
            assert self.parent_element() == u"host"
            assert self.current_host is not None
            try:
                addr = attrs[u"addr"]
            except KeyError:
                warn(u"%s element of host %s is missing the \"addr\" attribute; skipping." % (name, self.current_host.format_name()))
                return
            addrtype = attrs.get(u"addrtype", u"ipv4")
            self.current_host.add_address(addrtype, addr)
        elif name == u"hostname":
            assert self.parent_element() == u"hostnames"
            assert self.current_host is not None
            try:
                hostname = attrs[u"name"]
            except KeyError:
                warn(u"%s element of host %s is missing the \"name\" attribute; skipping." % (name, self.current_host.format_name()))
                return
            self.current_host.add_hostname(hostname)
        elif name == u"extraports":
            assert self.parent_element() == u"ports"
            assert self.current_host is not None
            try:
                state = attrs[u"state"]
            except KeyError:
                warn(u"%s element of host %s is missing the \"state\" attribute; assuming \"unknown\"." % (name, self.current_host.format_name()))
                state = Port.UNKNOWN
            if state in self.current_extraports:
                warn(u"Duplicate extraports state \"%s\" in host %s." % (state, self.current_host.format_name()))
            # Perhaps check for count == 0.
            self.current_extraports.append(state)
        elif name == u"port":
            assert self.parent_element() == u"ports"
            assert self.current_host is not None
            try:
                portid_str = attrs[u"portid"]
            except KeyError:
                warn(u"%s element of host %s missing the \"portid\" attribute; skipping." % (name, self.current_host.format_name()))
                return
            try:
                portid = int(portid_str)
            except ValueError:
                warn(u"Can't convert portid \"%s\" to an integer in host %s; skipping port." % (portid_str, self.current_host.format_name()))
                return
            try:
                protocol = attrs[u"protocol"]
            except KeyError:
                warn(u"%s element of host %s missing the \"protocol\" attribute; skipping." % (name, self.current_host.format_name()))
                return
            self.current_port = Port((portid, protocol))
        elif name == u"state":
            assert self.parent_element() == u"port"
            assert self.current_host is not None
            # current_port is None when the enclosing port element was skipped.
            if self.current_port is None:
                return
            try:
                state = attrs[u"state"]
            except KeyError:
                warn(u"%s element of port %s is missing the \"state\" attribute; assuming \"unknown\"." % (name, Port.spec_to_string(self.current_port.spec)))
                return
            self.current_port.state = state
            # The port is added to the host only once its state is known.
            self.current_host.add_port(self.current_port)
        elif name == u"service":
            assert self.parent_element() == u"port"
            assert self.current_host is not None
            if self.current_port is None:
                return
            self.current_port.service.name = attrs.get(u"name")
            self.current_port.service.product = attrs.get(u"product")
            self.current_port.service.version = attrs.get(u"version")
            self.current_port.service.extrainfo = attrs.get(u"extrainfo")
        elif name == u"osmatch":
            assert self.parent_element() == u"os"
            assert self.current_host is not None
            try:
                os_name = attrs[u"name"]
            except KeyError:
                warn(u"%s element of host %s is missing the \"name\" attribute; skipping." % (name, self.current_host.format_name()))
                return
            self.current_host.os.append(os_name)
        elif name == u"finished":
            assert self.parent_element() == u"runstats"
            finish_time = attrs.get(u"time")
            if finish_time is not None:
                self.scan.end_date = datetime.datetime.fromtimestamp(int(finish_time))

    def endElementAux(self, name):
        """Handle the end of the element called name. At the end of a host
        with exactly one extraports state, every scanned port that is not
        already known is filled in with that state."""
        if name == u"nmaprun":
            self.scanned_ports = None
        elif name == u"host":
            # With more than one extraports state there is no way to tell
            # which state an unlisted port is in, so nothing is filled in.
            if len(self.current_extraports) == 1:
                extraports_state = self.current_extraports[0]
                known_ports = self.current_host.get_known_ports()
                known_specs = set(port.spec for port in known_ports)
                for protocol in self.scanned_ports:
                    for portid in self.scanned_ports[protocol]:
                        spec = portid, protocol
                        if spec in known_specs:
                            continue
                        assert self.current_host.ports[spec].state == Port.UNKNOWN
                        self.current_host.add_port(Port(spec, state = extraports_state))

            self.current_host = None
            self.current_extraports = []
        elif name == u"port":
            self.current_port = None

def usage():
    """Print the usage message, with the program name taken from
    sys.argv[0], to standard output."""
    print u"""\
Usage: %s [option] FILE1 FILE2
Compare two Nmap XML files and display a list of their differences.
Differences include host state changes, port state changes, and changes to
service and OS detection.

  -h, --help     display this help
  -v, --verbose  don't consolidate long port lists into just a count
  --text         display output in text format (default)
  --xml          display output in XML format\
""" % sys.argv[0]

def usage_error(msg):
    """Report a command-line usage error and exit.

    Prints msg prefixed with the program name, followed by a hint about the
    -h option, to stderr, then exits with status 1. Does not return."""
    print >> sys.stderr, u"%s: %s" % (sys.argv[0], msg)
    print >> sys.stderr, u"Try '%s -h' for help." % sys.argv[0]
    sys.exit(1)

def main():
    output_format = None
    verbose = False

    try:
        opts, input_filenames = getopt.gnu_getopt(sys.argv[1:], "hv", ["help", "text", "verbose", "xml"])
    except getopt.GetoptError, e:
        print >> sys.stderr, u"%s: %s" % (sys.argv[0], e.msg)
        sys.exit(1)
    for o, a in opts:
        if o == "-h" or o == "--help":
            usage()
            sys.exit(0)
        elif o == "-v" or o == "--verbose":
            verbose = True
        elif o == "--text":
            if output_format is not None and output_format != "text":
                usage_error("contradictory output format options.")
            output_format = "text"
        elif o == "--xml":
            if output_format is not None and output_format != "xml":
                usage_error("contradictory output format options.")
            output_format = "xml"

    if len(input_filenames) != 2:
        usage_error("need exactly two input filenames.")

    if output_format is None:
        output_format = "text"

    filename_a = input_filenames[0]
    filename_b = input_filenames[1]

    scan_a = Scan()
    scan_a.load_from_file(filename_a)
    scan_b = Scan()
    scan_b.load_from_file(filename_b)

    diff = ScanDiff(scan_a, scan_b)

    if output_format == "text":
        diff.print_text(verbose = verbose)
    elif output_format == "xml":
        diff.print_xml()

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
