"""This module contains the code required to create and run
measurements for the 'Forensic Investigation of the OneSwarm Anonymous
Filesharing System' paper in CCS 2011.

Code assumes it is being run on the initiator, and that all one-time
setup (lighttpd configuration, putting the named pipe 'backpipe' into
place, etc.) is done beforehand, as per README.txt.

Copyright 2011 Marc Liberatore <liberato@cs.umass.edu>
"""

# library imports
import datetime
import getpass
import logging
import os
import time
import socket
import sys

# local imports
import pexpect
import pxssh

class MeasurementConfiguration(object):
    """Configuration and driver for one initiator/adjacent/distant triple.

    The initiator is the local host.  The adjacent host is reached over
    ssh (pxssh) and either serves the file itself through lighttpd
    (direct retrievals) or relays it from the distant host through a
    netcat pipe (indirect retrievals).  Each retrieval is performed once
    with wget and once with the local ``./attack`` program while tcpdump
    records the traffic on the initiator.

    Constructor arguments (short names kept for backward compatibility):
      ah  -- adjacent hostname
      au  -- ssh username on the adjacent host
      awp -- adjacent web server port (hardcoded in the lighttpd config)
      arp -- adjacent relay (netcat listen) port
      dh  -- distant hostname, the server the relay forwards to
      ip  -- initiator password for sudo; prompted for when None
      ap  -- adjacent ssh password/passphrase; prompted for when None
    """

    def __init__(self,
                 ah='boston.cs.ucsb.edu',
                 au='guest',
                 awp=8080,
                 arp=8081,
                 dh='mirror.mojohost.com',
                 ip=None,
                 ap=None):
        self.num_pings = 5
        self.initiator_hostname = socket.getfqdn()
        self.initiator_addr = socket.gethostbyname(self.initiator_hostname)
        if ip is None:
            ip = getpass.getpass('local password (for sudo): ')
        self.initiator_password = ip
        self.initiator_interface = 'eth0'
        # Relay command run on the adjacent host, formatted by start_relay()
        # with (adjacent_relay_port, distant_webserver_port).  It forwards
        # to the configured distant host instead of a fixed mirror.  Some
        # netcat variants need 'nc -l -p %d ...' rather than 'nc -l %d ...';
        # callers may overwrite this attribute for such hosts.
        self.relay_command_template = self._default_relay_template(dh)
        self.adjacent_hostname = ah
        self.adjacent_addr = socket.gethostbyname(self.adjacent_hostname)
        self.adjacent_username = au
        if ap is None:
            ap = getpass.getpass('adjacent password/phrase: ')
        self.adjacent_password = ap
        self.adjacent_webserver_port = awp # NOTE: hardcoded in config file in lighttpd
        self.adjacent_relay_port = arp
        self.distant_hostname = dh
        self.distant_webserver_port = 80
        self.remote_path = '/centos/5.6/isos/x86_64/CentOS-5.6-x86_64-netinstall.iso'
        self.attack_maxaug = 50
        self.attack_windowscaling = 10
        self.attack_timeout = 3000 # ms, used as option for attack program
        self.attack_command_timeout = 120 # s, used by expect() in case attack program freezes

        self.directory = None # will be initialized by create_measurement_directory()
        self.adjacent_webserver_shell = None # will be initialized by start_webserver()
        self.adjacent_relay_shell = None # will be initialized by open_relay_shell()

    @staticmethod
    def _default_relay_template(distant_hostname):
        """Return the default netcat relay template forwarding to
        distant_hostname.

        The result still contains two %d placeholders, filled in by
        start_relay() with (adjacent_relay_port, distant_webserver_port).
        """
        return ('nc -l %%d 0<backpipe | nc %s %%d 1>backpipe'
                % distant_hostname)

    def create_measurement_directory(self):
        """Create the directory that stores results from one set of
        measurements, and record it in self.directory.

        Raises OSError if the directory already exists, so an earlier
        measurement is never overwritten.
        """
        dirname = '%s-%s-%s-%s' % (self.initiator_hostname,
                                   self.adjacent_hostname,
                                   self.distant_hostname,
                                   datetime.datetime.now().isoformat())
        logging.info('creating measurement directory: %s' % dirname)

        # os.mkdir refuses to reuse an existing path, which is the
        # guarantee the old asserts tried (and under -O failed) to give.
        os.mkdir(dirname)

        self.directory = dirname

    def record_pings(self):
        """Record a series of pings between each host in the
        measurement configuration."""

        logging.info('starting pings')

        # initiator to others
        self.record_ping(self.initiator_hostname, self.adjacent_hostname)
        self.record_ping(self.initiator_hostname, self.distant_hostname)

        # adjacent to distant
        self.record_ping(self.adjacent_hostname, self.distant_hostname)

        logging.info('pings complete')

    def record_ping(self, pinger, pingee):
        """Worker method called by record_pings(): run num_pings pings
        from pinger to pingee and save them in the measurement directory."""
        outfilepath = os.path.join(self.directory,
                                   'ping-%s-%s.txt' % (pinger, pingee))
        command = 'ping -c %d %s' % (self.num_pings, pingee)
        self.record_command(pinger, outfilepath, command)

    def record_command(self, host, outfilepath, command):
        """Run command on host and save its output to outfilepath.

        host must be the initiator (the command runs locally and the
        command line is written as the first line of the file) or the
        adjacent host (the command runs in a fresh ssh session that is
        logged out afterwards; ssh failures are logged, not raised).

        Raises AssertionError for any other host.
        """
        if host == self.initiator_hostname:
            self._record_local_command(outfilepath, command)
        elif host == self.adjacent_hostname:
            self._record_remote_command(outfilepath, command)
        else:
            raise AssertionError('bad host configuration')

    def _record_local_command(self, outfilepath, command):
        """Run command locally; save the command line followed by its output."""
        logging.info('running command: %s' % command)
        try:
            result = pexpect.run(command)
        except pexpect.ExceptionPexpect:
            # Treated as "program unavailable"; name the actual program
            # (ping goes through here too, not just tracepath).
            result = '%s not found' % command.split()[0]
            logging.warning(result)
        logging.info('saving results to %s' % outfilepath)
        self._save_lines(outfilepath, [command] + result.splitlines())

    def _record_remote_command(self, outfilepath, command):
        """Run command over a fresh ssh session to the adjacent host and
        save the session output captured before the next prompt."""
        try:
            s = self._ssh_login()
        except pxssh.ExceptionPxssh as e:
            logging.error('pxssh failed! (%s)' % str(e))
            return
        try:
            logging.info('running command: %s' % command)
            s.sendline(command)
            s.prompt()
            resultlines = s.before.splitlines()
            logging.info('saving results %s' % outfilepath)
            self._save_lines(outfilepath, resultlines)
        except pxssh.ExceptionPxssh as e:
            logging.error('pxssh failed! (%s)' % str(e))
        finally:
            # Log out so each remote ping/tracepath does not leave an ssh
            # session open.
            self._close_ssh(s)

    def record_paths(self):
        """Record a set of tracepaths (tracepath is linux-specific;
        traceroute not yet implemented) between relevant hosts in this
        measurement configuration."""

        logging.info('starting tracepaths')

        # initiator to others
        self.record_path(self.initiator_hostname, self.adjacent_hostname)
        self.record_path(self.initiator_hostname, self.distant_hostname)

        # adjacent to distant
        self.record_path(self.adjacent_hostname, self.distant_hostname)

        logging.info('tracepaths complete')

    def record_path(self, startpoint, endpoint):
        """Worker method called by record_paths(): run tracepath from
        startpoint to endpoint and save it in the measurement directory."""
        outfilepath = os.path.join(self.directory,
                                   'tracepath-%s-%s.txt' % (startpoint, endpoint))
        command = 'tracepath %s' % endpoint
        self.record_command(startpoint, outfilepath, command)

    def _ssh_login(self):
        """Open and return a logged-in pxssh session to the adjacent host.

        Raises pxssh.ExceptionPxssh if the login fails.
        """
        logging.info('sshing to %s with username %s' %
                     (self.adjacent_hostname, self.adjacent_username))
        s = pxssh.pxssh()
        # XXX KLUDGE ALERT XXX: force_password is set for every host except
        # isectestbed.uta.edu (presumably that host uses key/agent auth --
        # confirm).
        if self.adjacent_hostname != 'isectestbed.uta.edu':
            s.force_password = True
        s.login(self.adjacent_hostname,
                self.adjacent_username,
                self.adjacent_password)
        return s

    def _close_ssh(self, s):
        """Log out of pxssh session s; force-close it if logout times out."""
        logging.info('logging out of %s' % self.adjacent_hostname)
        try:
            s.logout()
        except pexpect.TIMEOUT:
            try:
                s.close(force=True)
            except (pexpect.ExceptionPexpect, OSError) as e:
                logging.warning('could not force-close ssh session (%s)' % str(e))

    def start_webserver(self):
        """Start the web server on the adjacent host, and bind
        self.adjacent_webserver_shell to this shell.

        Exits the program if the ssh session cannot be established.
        """
        kill_command = 'killall lighttpd'
        # -D keeps lighttpd in the foreground of this shell, so
        # stop_webserver() can stop it with an interrupt.
        start_command = 'lighttpd-install/sbin/lighttpd -D -f lighttpd.conf'

        try:
            s = self._ssh_login()

            logging.info('running command: %s' % kill_command)
            s.sendline(kill_command)
            s.prompt()

            logging.info('running command: %s' % start_command)
            s.sendline(start_command)

            logging.info('web server started on host %s' %
                         (self.adjacent_hostname, ))

            self.adjacent_webserver_shell = s
        except pxssh.ExceptionPxssh as e:
            sys.exit('pxssh failed! (%s)' % str(e))

    def stop_webserver(self):
        """Stop the web server on the adjacent host, close the ssh
        connection, and unbind self.adjacent_webserver_shell.

        Does nothing (beyond a warning) if no web server shell is bound.
        """
        s = self.adjacent_webserver_shell
        if s is None:
            logging.warning('no web server shell bound on %s; nothing to stop'
                            % self.adjacent_hostname)
            return
        try:
            logging.info('killing web server on %s' %
                         self.adjacent_hostname)
            s.sendintr()
            s.prompt()

            logging.info('web server on %s killed' % self.adjacent_hostname)

            self._close_ssh(s)
            self.adjacent_webserver_shell = None
        except pxssh.ExceptionPxssh as e:
            logging.error('pxssh failed! (%s)' % str(e))
        except pexpect.TIMEOUT as e:
            logging.error('timeout on webserver shell close (%s)' % str(e))

    def _sudo_spawn(self, command, timeout=30):
        """Spawn a local sudo command (sudo must be part of command) and
        answer its password prompt; return the spawned pexpect shell.

        The sudo timestamp is invalidated first (sudo -k) so that a
        password prompt is always issued and can be expected.
        """
        pexpect.run('sudo -k')
        shell = pexpect.spawn(command, timeout=timeout)
        shell.expect('password')
        shell.waitnoecho()
        shell.sendline(self.initiator_password)
        return shell

    def _sudo_command(self, command, expected):
        """Run a sudoed command on the local machine; assume the sudo
        is part of the provided command, expect the expected value,
        and return the spawned pexpect shell."""
        shell = self._sudo_spawn(command)
        shell.expect(expected)
        return shell

    def _iptables_rule_command(self, action, remote_port):
        """Build the iptables command that applies action ('-A' to append,
        '-D' to delete) to the OUTPUT rule dropping RSTs sent from the
        initiator to remote_port on the adjacent host."""
        return ('sudo /sbin/iptables %s OUTPUT -p tcp --tcp-flags RST RST '
                '-s %s -d %s --dport %d -j DROP'
                % (action, self.initiator_addr, self.adjacent_addr, remote_port))

    def _run_iptables(self, action, remote_port):
        """Run the iptables rule command for action via sudo."""
        iptables_command = self._iptables_rule_command(action, remote_port)
        logging.info('running command: %s' % iptables_command)
        self._sudo_command(iptables_command,
                           [pexpect.EOF, pexpect.TIMEOUT])

    def insert_iptables_rule(self, remote_port):
        """Append (-A) the iptables rule that lets the attack program
        run, by dropping the RSTs the kernel's TCP implementation sends
        when receiving packets on a port it doesn't know is running our
        version of TCP.
        """
        self._run_iptables('-A', remote_port)

    def delete_iptables_rule(self, remote_port):
        """Delete the rule inserted by insert_iptables_rule."""
        self._run_iptables('-D', remote_port)

    def open_relay_shell(self):
        """Open and bind self.adjacent_relay_shell to an ssh shell to
        the adjacent host.  Exits the program if ssh fails."""
        try:
            self.adjacent_relay_shell = self._ssh_login()
        except pxssh.ExceptionPxssh as e:
            sys.exit('pxssh failed! (%s)' % str(e))

    def start_relay(self):
        """Start the relay (proxy) on the adjacent host."""
        try:
            start_command = self.relay_command_template % \
                (self.adjacent_relay_port, self.distant_webserver_port)
            logging.info('running command: %s' % start_command)
            self.adjacent_relay_shell.sendline(start_command)

            logging.info('relay started on host %s' %
                         (self.adjacent_hostname, ))
        except pxssh.ExceptionPxssh as e:
            sys.exit('pxssh failed! (%s)' % str(e))

    def finish_relay(self):
        """Interrupt the relay (proxy) on the adjacent host and wait for
        the shell prompt to return."""
        try:
            logging.info('awaiting end of relay process on %s' %
                         self.adjacent_hostname)
            self.adjacent_relay_shell.sendintr()
            self.adjacent_relay_shell.sendline()
            self.adjacent_relay_shell.prompt()
            logging.info('relay on %s finished' % self.adjacent_hostname)

        except pxssh.ExceptionPxssh as e:
            logging.error('pxssh failed! (%s)' % str(e))

    def close_relay_shell(self):
        """Close the ssh connection to the adjacent host and unbind
        self.adjacent_relay_shell.  Does nothing if no shell is bound."""
        s = self.adjacent_relay_shell
        if s is None:
            return
        try:
            # Presumably consumes the prompt produced by finish_relay()'s
            # trailing empty sendline() -- confirm.
            s.prompt()

            self._close_ssh(s)

            self.adjacent_relay_shell = None
        except pxssh.ExceptionPxssh as e:
            logging.error('pxssh failed! (%s)' % str(e))
        except pexpect.TIMEOUT:
            logging.error('timeout on relay logout.')

    def _tcpdump_command(self, pcap_fname, remote_port):
        """Build the tcpdump command capturing 68-byte snapshots of the
        traffic to/from remote_port on the adjacent host into pcap_fname."""
        return ('sudo /usr/sbin/tcpdump -s 68 -i %s -w %s '
                '(src %s and src port %d) or (dst %s and dst port %d)'
                % (self.initiator_interface,
                   pcap_fname,
                   self.adjacent_addr,
                   remote_port,
                   self.adjacent_addr,
                   remote_port))

    def _start_tcpdump(self, pcap_fname, remote_port):
        """Start tcpdump via sudo and return its shell once it reports
        that it is listening."""
        tcpdump_command = self._tcpdump_command(pcap_fname, remote_port)
        logging.info('running command: %s' % tcpdump_command)
        return self._sudo_command(tcpdump_command, 'tcpdump: listening')

    @staticmethod
    def _stop_tcpdump(tcpdump_shell):
        """Interrupt tcpdump and close its shell."""
        logging.info('closing tcpdump')
        tcpdump_shell.sendintr()
        # Presumably gives tcpdump time to flush the capture file and exit.
        time.sleep(1)
        tcpdump_shell.close()

    @staticmethod
    def _save_lines(path, lines):
        """Write each of lines to path, one per line ('\\n'-terminated)."""
        with open(path, 'w') as outfile:
            for line in lines:
                outfile.write(line)
                outfile.write('\n')

    def _save_run_log(self, path, command, resultlines, start_time, end_time):
        """Write a retrieval log: the command, its output lines, then the
        start/finish/elapsed wall-clock times in seconds."""
        self._save_lines(path,
                         [command] + resultlines +
                         ['started: %f' % start_time,
                          'finished: %f' % end_time,
                          'elapsed: %f' % (end_time - start_time)])

    def _wget_retrieval(self, wget_fname, i, remote_port):
        """Retrieve remote_path from the adjacent host on remote_port with
        wget while tcpdump records; save <wget_fname>-<i>.pcap and .log.
        tcpdump is stopped even if the retrieval fails."""
        tcpdump_shell = self._start_tcpdump(wget_fname + '-%d.pcap' % i,
                                            remote_port)
        try:
            wget_log_fname = wget_fname + '-%d.log' % i
            wget_command = 'wget -O /dev/null http://%s:%d%s' % \
                           (self.adjacent_hostname,
                            remote_port,
                            self.remote_path)

            logging.info('running command: %s' % wget_command)

            start_time = time.time()
            result = pexpect.run(wget_command, timeout=120)
            end_time = time.time()

            logging.info('saving results to %s' % wget_log_fname)
            self._save_run_log(wget_log_fname, wget_command,
                               result.splitlines(), start_time, end_time)
        finally:
            self._stop_tcpdump(tcpdump_shell)

    def _attack_retrieval(self, attack_fname, i, remote_port):
        """Retrieve remote_path from the adjacent host on remote_port with
        the ./attack program while tcpdump records; save
        <attack_fname>-<i>.pcap and .log.  The log ends with TIMEOUT or
        EOF when the program did not print 'Exit'.  tcpdump is stopped
        even if the retrieval fails."""
        tcpdump_shell = self._start_tcpdump(attack_fname + '-%d.pcap' % i,
                                            remote_port)
        try:
            attack_log_fname = attack_fname + '-%d.log' % i
            attack_command = 'sudo ./attack -s %s -d %s -p %d -a %d -w %d -D %d -g %s' % \
                             (self.initiator_addr,
                              self.adjacent_addr,
                              remote_port,
                              self.attack_maxaug,
                              self.attack_windowscaling,
                              self.attack_timeout,
                              self.remote_path)
            logging.info('running command: %s' % attack_command)

            attack_shell = self._sudo_spawn(attack_command,
                                            timeout=self.attack_command_timeout)
            # Timing starts once the password is sent, i.e. when the
            # attack program can actually begin.
            start_time = time.time()
            result = attack_shell.expect(['Exit', pexpect.TIMEOUT, pexpect.EOF])
            end_time = time.time()
            try:
                attack_shell.close(True)
            except pexpect.ExceptionPexpect:
                killcmd = 'sudo kill -9 %d' % attack_shell.pid
                logging.info('killing stuck attack command: %s' % killcmd)
                self._sudo_command(killcmd, [pexpect.TIMEOUT, pexpect.EOF])

            resultlines = attack_shell.before.splitlines()
            if result == 1:
                resultlines.append('TIMEOUT')
            elif result == 2:
                resultlines.append('EOF')
            elif result != 0:
                raise AssertionError('bad pexpect case')

            self._save_run_log(attack_log_fname, attack_command,
                               resultlines, start_time, end_time)
        finally:
            self._stop_tcpdump(tcpdump_shell)

    def _wget_direct(self, i):
        """Run wget retrieval i against the adjacent web server."""
        wget_fname = os.path.join(self.directory,
                                  'direct-wget-%s-%s' % (self.initiator_hostname,
                                                         self.adjacent_hostname))
        self._wget_retrieval(wget_fname, i, self.adjacent_webserver_port)

    def _attack_direct(self, i):
        """Run attack retrieval i against the adjacent web server."""
        attack_fname = os.path.join(self.directory,
                                    'direct-attack-%s-%s' % (self.initiator_hostname,
                                                             self.adjacent_hostname))
        self._attack_retrieval(attack_fname, i, self.adjacent_webserver_port)

    def record_direct_retrievals(self, num_runs=10):
        """Configure the adjacent host to serve the file directly, and
        record num_runs retrievals of the file of interest.  Record
        tcpdump logs locally, as well as the output from wget and our
        attack program.

        The iptables rule is removed and the web server stopped even if a
        retrieval fails.
        """
        self.start_webserver()
        try:
            self.insert_iptables_rule(self.adjacent_webserver_port)
            try:
                for i in range(num_runs):
                    self._wget_direct(i)
                    self._attack_direct(i)
            finally:
                self.delete_iptables_rule(self.adjacent_webserver_port)
        finally:
            self.stop_webserver()

    def _wget_relay(self, i):
        """Run wget retrieval i through the relay on the adjacent host."""
        wget_fname = os.path.join(self.directory, 'relay-wget-%s-%s-%s' %
                                  (self.initiator_hostname,
                                   self.adjacent_hostname,
                                   self.distant_hostname))
        self.start_relay()
        try:
            self._wget_retrieval(wget_fname, i, self.adjacent_relay_port)
        finally:
            self.finish_relay()

    def _attack_relay(self, i):
        """Run attack retrieval i through the relay on the adjacent host."""
        attack_fname = os.path.join(self.directory, 'relay-attack-%s-%s-%s' %
                                    (self.initiator_hostname,
                                     self.adjacent_hostname,
                                     self.distant_hostname))
        self.start_relay()
        try:
            self._attack_retrieval(attack_fname, i, self.adjacent_relay_port)
        finally:
            self.finish_relay()

    def record_indirect_retrievals(self, num_runs=10):
        """Configure the adjacent host to relay the file from the
        distant host, and record num_runs retrievals of the file of
        interest.  Record tcpdump logs locally, as well as the output
        from wget and our attack program.

        The iptables rule is removed and the relay shell closed even if
        a retrieval fails.
        """
        self.open_relay_shell()
        try:
            self.insert_iptables_rule(self.adjacent_relay_port)
            try:
                for i in range(num_runs):
                    self._wget_relay(i)
                    self._attack_relay(i)
            finally:
                self.delete_iptables_rule(self.adjacent_relay_port)
        finally:
            self.close_relay_shell()

    def run_retrievals(self, i=10):
        """Create a fresh measurement directory, record pings and
        tracepaths, then run i direct and i relayed retrieval rounds."""
        self.create_measurement_directory()
        self.record_pings()
        self.record_paths()
        self.record_direct_retrievals(i)
        self.record_indirect_retrievals(i)


def test():
    """Smoke-test one round of measurements against isectestbed.uta.edu.

    Enables debug logging, builds a configuration for user 'marc' with web
    server port 8080 and relay port 8081 (both passwords are prompted for),
    then runs pings, tracepaths and a single direct and relayed retrieval
    round.
    """
    logging.basicConfig(level=logging.DEBUG)
    config = MeasurementConfiguration(ah='isectestbed.uta.edu',
                                      au='marc',
                                      awp=8080,
                                      arp=8081)
    # For a relay whose netcat needs '-l -p', override
    # config.relay_command_template here before running.
    config.run_retrievals(1)

def all_measurements():
    """Run the full measurement matrix from this host.

    Every relay in the table below (except the initiator itself) is
    paired with every distant mirror, and each pairing runs 5 retrieval
    rounds.  Passwords come from a user-supplied ``passwords`` module
    exposing a ``password`` dict keyed by hostname; the program exits if
    that module is missing.
    """
    initiator = socket.getfqdn()
    # getfqdn() may return the short name on isectestbed; normalize it so
    # it matches the relay/password keys.
    if initiator == 'isectestbed':
        initiator = 'isectestbed.uta.edu'

    relays = {'prismslab.cs.umass.edu': 'liberato',
              'isectestbed.uta.edu': 'marc',
              'kurtz.cs.wesleyan.edu': 'mliberatore',
              'boston.cs.ucsb.edu': 'guest',
              }
    distants = ['mirror.hmc.edu',
                'mirror.mojohost.com',
                'mirrors.cmich.edu',
                ]
    try:
        import passwords # NOTE: you'll need your own relays and corresponding passwords :)
    except ImportError:
        sys.exit("You need to define your own passwords.py file (or don't call all_measurements() directly).")
    mcs = []
    local_pw = passwords.password[initiator]
    for relay, username in relays.items():
        if relay == initiator:
            continue
        remote_pw = passwords.password[relay]
        for distant in distants:
            m = MeasurementConfiguration(relay,
                                         username,
                                         dh=distant,
                                         ip=local_pw,
                                         ap=remote_pw)
            if relay == 'kurtz.cs.wesleyan.edu':
                # kurtz's netcat needs the '-l -p' listen syntax; forward
                # to this configuration's distant host, not a fixed mirror.
                m.relay_command_template = \
                    'nc -l -p %%d 0<backpipe | nc %s %%d 1>backpipe' % distant
            mcs.append(m)
    for m in mcs:
        m.run_retrievals(5)
    
    
if __name__ == '__main__':
    # Script entry point: log every step at debug level, then run the
    # complete measurement matrix.
    debug_level = logging.DEBUG
    logging.basicConfig(level=debug_level)
    all_measurements()
