#!/usr/bin/env python

import os
import sys
import subprocess
import pexpect
import re
import time
import signal
import threading

# Debug switch.  NOTE(review): no code visible in this file reads it.
debug = False
#debug = True
# Matches the log line that reports both exit codes
# ("... EXITSTATUS_CLIENT: <n> EXITSTATUS_SERVER: <n>").  run_tests() takes the
# client status from group 3 and the server status from group 6.
regex_exitstatus = re.compile(r'^(.*)EXITSTATUS_CLIENT:(\s+)(\d+)(\s+)EXITSTATUS_SERVER:(\s+)(\d+)(\s+)?$')
# Earlier variants of the pattern, kept for reference:
#regex_exitstatus = re.compile(r'EXITSTATUS_CLIENT:(\s+)(\d+)(\s+)EXITSTATUS_SERVER:(\s+)(\d+)(\s+)?$')
#regex_exitstatus = re.compile(r'EXITSTATUS_CLIENT:(\s+)(\d+)(\s+)EXITSTATUS_SERVER:(\s+)(\d+)$')
#regex_exitstatus = re.compile(r'EXITSTATUS_CLIENT:')

# Counter used to number the generated scripts (run_test_<n>); incremented by
# create_test() each time it writes one.
test_seq = 0
# Port number handed to the first generated test; create_tests() adds one for
# every further test it generates.
portno_start = 49917

def mk_client_cmd(base_dir,log_dir,run_script,cipher_string,hw_eng_client,hw_eng_serv,example_server,example_client,client_id,log_prefix,portno,use_atmel_ca=1,new_key=0):
   """Build the shell command, log file name and environment for one client run.

   Prints log_prefix, the command, and a '#RERUN:' line (the extra environment
   variables followed by the command) to stdout.

   Args:
      base_dir: directory that holds the run script.
      log_dir: directory used in the returned log file name.
      run_script: script file name.  If it contains '_der_', that marker is
         collapsed to '_' for the command and USE_ATMEL_CA is exported.
      cipher_string: exported as SSL_CIPHER; USE_RSA=1 is added when the
         cipher is listed in the module-level ``ciphers_rsa``.
      hw_eng_client: exported as USE_ENGINE.
      hw_eng_serv: used only in the log file name and the printed banner.
      example_server: used only in the log file name.
      example_client: exported as USE_EXAMPLE.
      client_id: client index, used only in the log file name.
      log_prefix: leading part of the log file name.
      portno: exported as PORT_NUMBER.
      use_atmel_ca: value for USE_ATMEL_CA (only exported for '_der_' scripts).
      new_key: exported as NEW_KEY.

   Returns:
      (cmd_client, fname_log, my_env), where my_env is a copy of os.environ
      overlaid with the variables above.
   """
   print(log_prefix)

   # The log name uses the script name exactly as passed in, i.e. before the
   # '_der_' marker is stripped below.
   log_fields = (log_dir,log_prefix,cipher_string,client_id,hw_eng_client,hw_eng_serv,example_server,example_client,run_script)
   fname_log = '%s/%s_%s_clnt_%0d_hwc_%0d_hws_%0d_exsrv_%0d_exclnt_%0d_%s.log' % log_fields

   # Keys are added one by one, in a fixed order, so that the order of the
   # variables in the '#RERUN:' line stays stable.
   env_vars = {}
   if '_der_' in run_script:
      run_script = run_script.replace('_der_','_')
      env_vars['USE_ATMEL_CA'] = '%d' % (use_atmel_ca)

   cmd_client = '%s/%s 2>&1' % (base_dir,run_script)
   print('*** cmd_client: %s, hwc=%0d hws=%0d ***' % (cmd_client,hw_eng_client,hw_eng_serv))

   env_vars['PORT_NUMBER'] = '%d' % (portno)
   env_vars['USE_ENGINE'] = '%d' % (hw_eng_client)
   env_vars['SSL_CIPHER'] = '%s' % (cipher_string)
   env_vars['NEW_KEY'] = '%d' % (new_key)
   if cipher_string in ciphers_rsa:
      env_vars['USE_RSA'] = '1'
   env_vars['USE_EXAMPLE'] = '%d' % (example_client)

   # "VAR=value " pairs followed by the command: a line that can be pasted
   # into a shell to repeat this client run by hand.
   assignments = ''.join('%s=%s ' % (name,value) for (name,value) in env_vars.items())
   print('#RERUN: %s' % ('%s %s' % (assignments,cmd_client)))

   return (cmd_client,fname_log,merge_dicts(os.environ.copy(),env_vars))

def mk_server_cmd(base_dir,log_dir,run_script,cipher_string,hw_eng_client,hw_eng_serv,example_server,example_client,log_prefix,portno,use_atmel_ca=1,new_key=0):
   """Build the shell command, log file name and environment for the server run.

   Prints the command and a '#RERUN:' line (the extra environment variables
   followed by the command) to stdout.

   Args:
      base_dir: directory that holds the run script.
      log_dir: directory used in the returned log file name.
      run_script: script file name.  If it contains '_der_', that marker is
         collapsed to '_' for the command and USE_ATMEL_CA is exported.
      cipher_string: exported as SSL_CIPHER.  'DH-RSA-AES256-SHA256'
         additionally exports TARGET=dh, and ciphers listed in the
         module-level ``ciphers_rsa`` export USE_RSA=1.
      hw_eng_client: used only in the log file name and the printed banner.
      hw_eng_serv: exported as USE_ENGINE.
      example_server: exported as USE_EXAMPLE.
      example_client: used only in the log file name.
      log_prefix: leading part of the log file name.
      portno: exported as PORT_NUMBER.
      use_atmel_ca: value for USE_ATMEL_CA (only exported for '_der_' scripts).
      new_key: exported as NEW_KEY.

   Returns:
      (cmd_server, fname_log, my_env), where my_env is a copy of os.environ
      overlaid with the variables above.
   """
   # The log name uses the script name exactly as passed in, i.e. before the
   # '_der_' marker is stripped below.
   log_fields = (log_dir,log_prefix,cipher_string,hw_eng_client,hw_eng_serv,example_server,example_client,run_script)
   fname_log = '%s/%s_%s_hwc_%0d_hws_%0d_exsrv_%0d_exclnt_%0d_%s.log' % log_fields

   # Keys are added one by one, in a fixed order, so that the order of the
   # variables in the '#RERUN:' line stays stable.
   env_vars = {}
   if '_der_' in run_script:
      run_script = run_script.replace('_der_','_')
      env_vars['USE_ATMEL_CA'] = '%d' % (use_atmel_ca)

   cmd_server = '%s/%s 2>&1' % (base_dir,run_script)
   print('*** cmd_server: %s, hwc=%0d hws=%0d ***' % (cmd_server,hw_eng_client,hw_eng_serv))

   env_vars['PORT_NUMBER'] = '%d' % (portno)
   env_vars['USE_ENGINE'] = '%d' % (hw_eng_serv)
   env_vars['SSL_CIPHER'] = '%s' % (cipher_string)
   env_vars['NEW_KEY'] = '%d' % (new_key)
   if cipher_string == 'DH-RSA-AES256-SHA256':
      # This cipher requires TARGET="dh" on the server side.
      env_vars['TARGET'] = 'dh'
   if cipher_string in ciphers_rsa:
      env_vars['USE_RSA'] = '1'
   env_vars['USE_EXAMPLE'] = '%d' % (example_server)

   # "VAR=value " pairs followed by the command: a line that can be pasted
   # into a shell to repeat this server run by hand.
   assignments = ''.join('%s=%s ' % (name,value) for (name,value) in env_vars.items())
   print('#RERUN: %s' % ('%s %s' % (assignments,cmd_server)))

   return (cmd_server,fname_log,merge_dicts(os.environ.copy(),env_vars))

def create_test(client_cmd_lst,cmd_server,env_server,fname_log_server,gen_cert,gen_rsa_cert,script_prefix,hw_cert,cert_test,atmel_ca):
   """Generate one stand-alone test script and register it in the driver script.

   The test-specific statements are written to a temporary file ('tmp.py' in
   the current directory), which is then concatenated between
   test_engine_pre.py and test_engine_post.py to form '<script_prefix>.py'.
   That script is made executable and its path is appended to the driver
   script named by the module-level ``fname_drv`` (under the module-level
   ``base_dir``).  The module-level ``test_seq`` counter is incremented.

   Args:
      client_cmd_lst: list of (cmd_client, fname_log_client, env_client)
         tuples, one per client.
      cmd_server: server command line.
      env_server: environment dict for the server.
      fname_log_server: server log file name.
      gen_cert: when true, emit a run_cert_cmd() call for ``hw_cert``.
      gen_rsa_cert: when true, emit a run_rsa_cert() call.
      script_prefix: generated script name without the '.py' suffix.
      hw_cert: argument written into the run_cert_cmd() call for ``gen_cert``.
      cert_test: when true, emit run_cert_cmd() calls for 'serv' and 'client'.
      atmel_ca: when true, emit a run_cert_cmd() call that runs
         'run_extract_certs.sh'.
   """
   global test_seq
   script_name = '%s.py' % (script_prefix)
   tmpfile = 'tmp.py'

   # Every generated statement is tab-indented: it is spliced between the
   # pre/post files, which are assumed to supply the enclosing block.
   body = ['\tclient_cmd_lst = [\n']
   for (cmd_client,fname_log_client,env_client) in client_cmd_lst:
      # The trailing comma matters: without it, two consecutive entries
      # parse as a call expression, "(...)(...)", and the generated script
      # fails as soon as there is more than one client.
      body.append("\t('%s','%s',%s),\n" % (cmd_client,fname_log_client,env_client))
   body.append('\t]\n')
   for (cmd_client,fname_log_client,env_client) in client_cmd_lst:
      body.append("\tos.system('rm -f %s')\n" % (fname_log_client))
   body.append("\tcmd_server = '%s'\n" % (cmd_server))
   body.append("\tenv_server = %s\n" % (env_server,))
   body.append("\tfname_log_server = '%s'\n" % (fname_log_server))
   body.append("\tos.system('rm -f %s')\n" % (fname_log_server))

   if atmel_ca:
      body.append("\trun_cert_cmd(base_dir,fname_log_server,None,cmd_cert='run_extract_certs.sh')\n")

   if cert_test:
      body.append("\trun_cert_cmd(base_dir,fname_log_server,'serv')\n")
      body.append("\trun_cert_cmd(base_dir,fname_log_server,'client')\n")

   if gen_cert:
      body.append("\trun_cert_cmd(base_dir,fname_log_server,'%s')\n" % (hw_cert))

   if gen_rsa_cert:
      body.append('\trun_rsa_cert(base_dir,fname_log_server)\n')

   with open(tmpfile,'w') as f_test:
      f_test.writelines(body)

   test_seq += 1

   cmd = 'cat test_engine_pre.py %s test_engine_post.py > %s;chmod +x %s;rm -f %s' % (tmpfile,script_name,script_name,tmpfile)
   print(cmd)
   os.system(cmd)

   # Add to the test driver script.  (A debug guard here used to be able to
   # write the entry commented out; it was hard-wired to "always enabled".)
   write_test_script(fname_drv,'%s/%s' % (base_dir,script_name))

def _select_hw_cert(cert_test,engine_server,engine_client,prev_eng):
   """Return (gen_cert, hw_cert) for a combination, given the previous engine pair."""
   if cert_test:
      return (False,None)
   # A certificate is requested only when a side switches from engine 0 to
   # engine 1 relative to the previously visited combination; the server
   # side takes precedence.
   if engine_server == 1 and prev_eng[0] == 0:
      return (True,'serv')
   if engine_client == 1 and prev_eng[1] == 0:
      return (True,'client')
   return (False,None)

def create_tests(base_dir,log_dir,test_script_pairs,cipher_db,num_clients=1,test_engine_server=None,test_engine_client=None,test_example_server=None,test_example_client=None,valid_combos=None,test_atmel_ca_server=None,test_atmel_ca_client=None,cert_test=False,atmel_ca=False):
   """Generate one test script per accepted parameter combination and cipher.

   Walks the cross product of script pairs, engine, example and Atmel-CA
   settings (rightmost varying fastest).  Each combination that appears in
   ``valid_combos`` yields one generated test per entry of ``cipher_db``,
   each on its own port number starting at the module-level ``portno_start``.

   Args:
      base_dir: directory holding the run scripts.
      log_dir: directory used for log file names.
      test_script_pairs: list of (client_script, server_script) names.
      cipher_db: iterable of cipher strings.
      num_clients: number of client commands generated per test.
      test_engine_server, test_engine_client: engine values to sweep.
      test_example_server, test_example_client: example values to sweep.
      valid_combos: container of accepted tuples (example_server,
         engine_server, use_atmel_ca_server, example_client, engine_client,
         use_atmel_ca_client).
      test_atmel_ca_server, test_atmel_ca_client: Atmel-CA values to sweep.
      cert_test: forwarded to create_test(); also disables the engine-change
         certificate request.
      atmel_ca: forwarded to create_test().

   Returns:
      Dict mapping (str(client_log_list), server_log) to
      (client_log_list, server_log) for every generated test.

   The sweep lists and ``valid_combos`` default to None but must be supplied
   by the caller; they are iterated / searched directly.
   """
   import itertools

   nxt_prev_eng = (0,0)
   test_db = {}
   portno = portno_start

   print('TE:  %s %s' % (test_engine_server,test_engine_client))

   sweep = itertools.product(test_script_pairs,
                             test_engine_server,test_engine_client,
                             test_example_server,test_example_client,
                             test_atmel_ca_server,test_atmel_ca_client)
   for ((run_scr_client,run_scr_server),engine_server,engine_client,example_server,example_client,use_atmel_ca_server,use_atmel_ca_client) in sweep:
      combo = (example_server, engine_server, use_atmel_ca_server,
               example_client, engine_client, use_atmel_ca_client)

      prev_eng = nxt_prev_eng

      print('combo:  %s' % (combo,))
      print('valid_combos:  %s' % (valid_combos,))
      if combo not in valid_combos:
         print('skipping: %s' % (str(combo)))
         # Skipped combinations still count as "previous" for the
         # engine-change detection in _select_hw_cert().
         nxt_prev_eng = (engine_server,engine_client)
         continue
      print('accepted: %s ' % (str(combo)))

      print('test_engine_server: %d test_engine_client: %d' %(engine_server,engine_client))

      (gen_cert,hw_cert) = _select_hw_cert(cert_test,engine_server,engine_client,prev_eng)
      print('gen_cert: %s eng_srv: %d eng_clt: %d' % (gen_cert,engine_server,engine_client))

      # NOTE(review): use_atmel_ca_server / use_atmel_ca_client only take part
      # in the valid_combos filter above; they are not forwarded to
      # mk_server_cmd() / mk_client_cmd(), which therefore always use their
      # default use_atmel_ca.  Confirm this is intended.
      for cipher_string in cipher_db:
         print('******************************************************************')
         print('*** Cipher string: %s **********' % (cipher_string))
         print('******************************************************************')

         # test_seq is advanced by create_test(), so each test gets a fresh name.
         script_prefix = 'run_test_%0d' % (test_seq)
         log_prefix = script_prefix

         gen_rsa_cert = bool(re.search('RSA',cipher_string))

         client_cmd_lst = [
            mk_client_cmd(base_dir,log_dir,run_scr_client,cipher_string,engine_client,engine_server,example_server,example_client,client_id,log_prefix,portno)
            for client_id in range(num_clients)
         ]
         fname_log_client_lst = [client_entry[1] for client_entry in client_cmd_lst]

         (cmd_server,fname_log_server,env_server) = mk_server_cmd(base_dir,log_dir,run_scr_server,cipher_string,engine_client,engine_server,example_server,example_client,log_prefix,portno)

         create_test(client_cmd_lst,cmd_server,env_server,fname_log_server,gen_cert,gen_rsa_cert,script_prefix,hw_cert,cert_test,atmel_ca)
         test_db[(str(fname_log_client_lst),fname_log_server)] = (fname_log_client_lst,fname_log_server)

         nxt_prev_eng = (engine_server,engine_client)
         portno += 1

   return test_db

def debug_errs(fname_debug,fname_log_client_lst,fname_log_server):
   """Scan client and server logs for error markers and append hits to fname_debug.

   Each log is searched with a case-insensitive egrep run through the shell;
   matching lines are appended to ``fname_debug`` under a header naming the
   log.  Client lines containing "verify error" are filtered out.

   Args:
      fname_debug: file that receives headers and matching lines (appended).
      fname_log_client_lst: list of client log file names.
      fname_log_server: server log file name.

   Returns:
      (dstatus_client, dstatus_server): each is 1 when the corresponding scan
      exited with status 0 (for the clients: for at least one log), else 0.
   """
   separator = '=======================\n'
   client_scan = r'(egrep -i "Segmentation fault|error|Cipher is \(NONE\)|Connection refused" %s | egrep -v "verify error" 2>&1) >> %s'
   server_scan = '(egrep -i "Address already in use|error" %s 2>&1) >> %s'

   def append_text(text):
      # Open and close for every note so the text is on disk before the
      # shell command appends its own output to the same file.
      with open(fname_debug,'a') as f_debug:
         f_debug.write(text)

   append_text('\n' + separator)

   dstatus_client = 0
   for fname_log_client in fname_log_client_lst:
      append_text('** client dbg **: %s\n' % (fname_log_client))
      scan_rc = subprocess.call(client_scan % (fname_log_client,fname_debug),shell=True)
      if scan_rc == 0:
         dstatus_client = 1

   append_text('** server dbg **: %s\n' % (fname_log_server))
   scan_rc = subprocess.call(server_scan % (fname_log_server,fname_debug),shell=True)
   dstatus_server = 1 if scan_rc == 0 else 0

   append_text(separator)

   return (dstatus_client,dstatus_server)

def testing_summary(results,log_dir):
   """Print and log a pass/fail summary of the client side of every test.

   Recreates '<log_dir>/debug.log', runs debug_errs() on each result's logs,
   appends a PASS_CLIENT / FAIL_CLIENT line per test, prints the sorted
   PASS/FAIL lines and finally prints and logs the totals.  A client passes
   when both its recorded exit status and its debug_errs() status are 0.
   The server status is carried in ``results`` but is not reported.

   Args:
      results: dict whose values are (status_client, status_server,
         fname_log_client_lst, fname_log_server), as produced by run_tests().
      log_dir: directory in which debug.log is written.
   """
   stars = '*****************************************************'
   print(stars)
   print('****************** TEST RESULTS *********************')
   print(stars)

   fname_debug = '%s/debug.log' % (log_dir)

   # Start from a clean debug log.
   cmd = 'rm -f %s' % (fname_debug)
   print(cmd)
   os.system(cmd)

   pass_client_cnt = 0
   fail_client_cnt = 0
   for (key,record) in results.items():
      (status_client,status_server,fname_log_client_lst,fname_log_server) = record
      (dstatus_client,dstatus_server) = debug_errs(fname_debug,fname_log_client_lst,fname_log_server)

      client_passed = (status_client == 0 and dstatus_client == 0)
      if client_passed:
         status_str = 'PASS_CLIENT\t'
         pass_client_cnt += 1
      else:
         status_str = 'FAIL_CLIENT\t%0d %0d' % (status_client,dstatus_client)
         fail_client_cnt += 1

      with open(fname_debug,'a') as f_debug:
         f_debug.write('%s:\t\t%s\n' % (key,status_str) + '****************\n')

   # Echo the per-test verdict lines, sorted, to stdout.
   os.system('cat %s | egrep "PASS|FAIL" | sort' % (fname_debug))

   pass_str = 'CLIENT_PASS:\t%0d' % (pass_client_cnt)
   fail_str = 'CLIENT_FAIL:\t%0d' % (fail_client_cnt)
   with open(fname_debug,'a') as f_debug:
      f_debug.write('*********************************\n')
      print(pass_str)
      f_debug.write('%s:\n' % (pass_str))

      print(fail_str)
      f_debug.write('%s:\n' % (fail_str))
      f_debug.write('*********************************\n')
def merge_dicts(x, y):
    """Return a new dict with the entries of x overlaid by those of y.

    The result starts as a shallow copy of x; keys present in both take
    their value from y.  Neither argument is modified.
    """
    merged = x.copy()
    merged.update(y)
    return merged

def write_test_script(fname_script,text,init=False):
   """Write one line of text to a script file and make a new file executable.

   Args:
      fname_script: path of the script file.
      text: line to write; a newline is added.
      init: when true the file is truncated first, otherwise the line is
         appended.

   'chmod +x' is run (and echoed to stdout) only when the file did not exist
   before this call; truncating an existing file does not re-run it.
   """
   existed_before = os.path.exists(fname_script)

   with open(fname_script,'w' if init else 'a') as f_out:
      f_out.write('%s\n' % (text))

   if existed_before:
      return

   cmd = 'chmod +x %s' % (fname_script)
   print(cmd)
   os.system(cmd)

def run_tests(fname_drv,test_db):
   """Run the driver script, then collect exit statuses from the server logs.

   Args:
      fname_drv: command line of the driver script; executed through
         os.system() and echoed to stdout first.
      test_db: dict whose values are (fname_log_client_lst, fname_log_server),
         as returned by create_tests().

   Returns:
      Dict mapping (str(fname_log_client_lst), fname_log_server) to
      (client_status, server_status, fname_log_client_lst, fname_log_server).
      Both statuses come from the first line of the server log that matches
      the module-level ``regex_exitstatus``; they stay at -1 when the log
      cannot be read or contains no such line.
   """
   cmd = '%s' % (fname_drv)
   print(cmd)
   os.system(cmd)
   results = {}
   print('**** run_tests() **** ')
   for (fname_log_client_lst,fname_log_server) in test_db.values():
      (client_status,server_status) = (-1,-1)
      try:
         with open(fname_log_server,'r') as f_log:
            for line in f_log:
               m = regex_exitstatus.match(line.strip())
               if m:
                  client_status = int(m.group(3))
                  server_status = int(m.group(6))
                  break
      except (IOError,OSError):
         # Only a missing or unreadable log is tolerated (statuses stay -1).
         # Anything else, e.g. a programming error or Ctrl-C, must not be
         # hidden behind this message.
         print('Could not open log file: %s' % (fname_log_server))

      results[(str(fname_log_client_lst),fname_log_server)] = (client_status,
                                                               server_status,
                                                               fname_log_client_lst,
                                                               fname_log_server)
   return results

if __name__ == "__main__":
   # A single command-line argument containing 'gen' selects generate-only
   # mode: the test scripts and the driver are written but nothing is run.
   gen_only = (len(sys.argv) == 2 and re.search('gen',sys.argv[1]) is not None)

   # Per-suite switches; edit these to generate or run a subset.
   gen_rsa = True
   gen_ecdsa_0 = True
   gen_ecdsa_1 = True
   run_rsa = run_ecdsa_0 = run_ecdsa_1 = not gen_only

   # base_dir, fname_drv and ciphers_rsa are read as module-level names by
   # the functions above, so they must be defined before create_tests() runs.
   base_dir = '.'
   log_dir = 'log'
   fname_drv = '%s/runit.sh' % (base_dir)

   # Start a fresh driver script; the first line truncates any old file.
   driver_header = ['#!/bin/sh','set -x','mkdir -p %s/%s' % (base_dir,log_dir)]
   for (line_idx,header_line) in enumerate(driver_header):
      write_test_script(fname_drv,header_line,init=(line_idx == 0))

   all_results = {}

   # RSA test ciphers
   ciphers_rsa = [
      'ECDHE-RSA-AES128-SHA',
      'ECDHE-RSA-AES128-SHA256',
      'ECDHE-RSA-AES128-GCM-SHA256',
      'ECDHE-RSA-AES256-SHA'
   ]

   # ECDSA test ciphers
   ciphers_ecdsa = [
      'ECDH-ECDSA-AES128-SHA',
      'ECDH-ECDSA-AES128-SHA256',
      'ECDH-ECDSA-AES128-GCM-SHA256',
      'ECDHE-ECDSA-AES128-SHA',
      'ECDHE-ECDSA-AES128-SHA256',
      'ECDHE-ECDSA-AES128-GCM-SHA256'
   ]

   # Accepted combinations, as (example_server, engine_server,
   # use_atmel_ca_server, example_client, engine_client, use_atmel_ca_client).
   # Without the Atmel CA (used by the RSA and the DER suites):
   combos_no_atmel_ca = [
      (0,0,0,0,0,0),
      (1,0,0,0,0,0),
      (0,0,0,1,0,0),
      (1,0,0,1,0,0),
      (0,1,0,0,0,0),
      (1,1,0,0,0,0),
      (0,0,0,0,1,0),
      (0,0,0,1,1,0),
      (1,1,0,1,0,0),
      (1,0,0,1,1,0)
   ]
   # With the Atmel CA on whichever side uses the engine:
   combos_atmel_ca = [
      (0,1,1,0,0,0),
      (1,1,1,0,0,0),
      (0,0,0,0,1,1),
      (0,0,0,1,1,1),
      (1,1,1,1,0,0),
      (1,0,0,1,1,1)
   ]

   plain_scripts = [('run_client.sh','run_server.sh')]
   der_scripts = [('run_der_client.sh','run_server.sh'),('run_client.sh','run_der_server.sh')]

   # Each suite: (generate?, run?, keyword arguments for create_tests()).
   # The example sweeps are fixed at [0]; [1] and [0,1] were earlier
   # experiments.
   suites = [
      # RSA
      (gen_rsa,run_rsa,dict(
         test_script_pairs=plain_scripts,
         cipher_db=ciphers_rsa,
         test_engine_server=[0,1],
         test_engine_client=[0,1],
         test_example_server=[0],
         test_example_client=[0],
         num_clients=1,
         valid_combos=combos_no_atmel_ca,
         test_atmel_ca_server=[0],
         test_atmel_ca_client=[0])),
      # ECDSA tests - Atmel CA defined
      (gen_ecdsa_0,run_ecdsa_0,dict(
         test_script_pairs=plain_scripts,
         cipher_db=ciphers_ecdsa,
         test_engine_server=[0,1],
         test_engine_client=[0,1],
         test_example_server=[0],
         test_example_client=[0],
         valid_combos=combos_atmel_ca,
         test_atmel_ca_server=[0,1],
         test_atmel_ca_client=[0,1],
         cert_test=True,
         atmel_ca=True)),
      # ECDSA tests - no Atmel CA defined, DER scripts
      (gen_ecdsa_1,run_ecdsa_1,dict(
         test_script_pairs=der_scripts,
         cipher_db=ciphers_ecdsa,
         test_engine_server=[0,1],
         test_engine_client=[0,1],
         test_example_server=[0],
         test_example_client=[0],
         valid_combos=combos_no_atmel_ca,
         test_atmel_ca_server=[0,1],
         test_atmel_ca_client=[0,1],
         cert_test=True,
         atmel_ca=False)),
   ]

   # NOTE(review): the driver script is only ever appended to, and
   # run_tests() executes the whole driver, so each later suite's run also
   # re-executes the tests of the suites generated before it.
   for (do_generate,do_run,suite_args) in suites:
      if do_generate:
         test_db = create_tests(base_dir,log_dir,**suite_args)

      if not gen_only and do_generate and do_run:
         results = run_tests(fname_drv,test_db)
         all_results = merge_dicts(all_results,results)

   if run_rsa or run_ecdsa_0 or run_ecdsa_1:
      testing_summary(all_results,log_dir)

