#!/usr/bin/python3
#  Author: Jamie Strandboge <jamie@ubuntu.com>
#  Copyright (C) 2013-2015 Canonical Ltd.
#
#  This script is distributed under the terms and conditions of the GNU General
#  Public License, Version 3 or later. See http://www.gnu.org/copyleft/gpl.html
#  for details.

from __future__ import print_function
import glob
import optparse
import os
import re
import shutil
import subprocess
import sys
import tempfile
import unittest
import yaml

topdir = None
debugging = False


def float_representer(dumper, data):
    '''Represent a float as a YAML string, padding "NN.N" to "NN.N0".

    str(15.10) is "15.1", so dumping policy_version: 15.10 as a plain float
    would lose the trailing zero. Any value whose str() is exactly two digits,
    a dot and one digit therefore gets a "0" appended ("15.1" -> "15.10").
    This is not ideal -- a genuine 15.1 is rewritten too -- but it keeps
    policy_version: 15.10 intact.

    :param dumper: yaml dumper whose represent_str() builds the node
    :param data: float value being serialized
    :returns: whatever dumper.represent_str() returns for the (padded) text
    '''
    text = str(data)
    needs_padding = re.search(r'^[0-9][0-9]\.[0-9]$', text) is not None
    return dumper.represent_str(text + '0' if needs_padding else text)
# Route every float that yaml.dump() serializes with PyYAML's default Dumper
# (e.g. policy_version in T._emit_yaml) through float_representer. This is a
# module-level side effect performed at import time.
yaml.add_representer(float, float_representer)


def recursive_rm(dirPath, contents_only=False):
    '''Recursively remove a directory tree.

    Symbolic links found inside the tree are unlinked, never followed, so
    nothing outside dirPath is touched.

    :param dirPath: directory to remove
    :param contents_only: when True, remove everything inside dirPath but
                          keep dirPath itself (left empty)
    :raises OSError: if an entry cannot be removed
    '''
    if not contents_only:
        # shutil.rmtree already does the recursive walk and refuses to
        # follow a symlink given as the root.
        shutil.rmtree(dirPath)
        return

    for name in os.listdir(dirPath):
        path = os.path.join(dirPath, name)
        # A symlink to a directory must be unlinked, not recursed into.
        if os.path.islink(path) or not os.path.isdir(path):
            os.unlink(path)
        else:
            shutil.rmtree(path)


def cmd(command):
    '''Run command and capture its combined output.

    :param command: argument list passed straight to subprocess.Popen
    :returns: two-element list [returncode, output], where output is the
              command's stdout and stderr interleaved as text. If the command
              cannot be started (OSError), returns [127, <error message>].
    '''
    try:
        proc = subprocess.Popen(command,
                                universal_newlines=True,
                                stderr=subprocess.STDOUT,
                                stdout=subprocess.PIPE)
    except OSError as err:
        return [127, str(err)]

    output, _ = proc.communicate()
    return [proc.returncode, str(output)]


def debug(s):
    '''Print s prefixed with "DEBUG: " when the module-level debugging flag
    is set; do nothing otherwise.'''
    if debugging:
        print("DEBUG: %s" % s)


class T(unittest.TestCase):
    '''sc-filtergen tests.

    Each test works on a private copy of data/seccomp/templates and
    data/seccomp/policygroups in a temporary directory, builds a manifest in
    self.sc, and runs sc-filtergen against it.
    '''

    def setUp(self):
        '''Setup for tests: copy the policy data and build a base manifest'''
        # topdir is assigned by the __main__ block before the tests run
        self.data_dir = os.path.join(topdir, 'data', 'seccomp')

        self.tmpdir = tempfile.mkdtemp(prefix='test-sc-filtergen-')
        try:
            for subdir in ('templates', 'policygroups'):
                shutil.copytree(os.path.join(self.data_dir, subdir),
                                os.path.join(self.tmpdir, subdir))
        except (OSError, shutil.Error):
            # tearDown() is not run when setUp() fails; don't leak tmpdir
            recursive_rm(self.tmpdir)
            raise

        self.name = "testme"
        self.appname = "foo"
        self.version = "0.1"
        self.profile_name = "%s_%s_%s" % (self.name, self.appname,
                                          self.version)

        # create a profile (the manifest dict that _emit_yaml serializes)
        self.sc = dict()
        self.sc['policy_version'] = float(15.10)
        self.sc['policy_vendor'] = "ubuntu-personal"

    def tearDown(self):
        '''Clean up after each test_* function'''
        if os.path.exists(self.tmpdir):
            recursive_rm(self.tmpdir)

    def _add_policy_group(self, g):
        '''Add policy group g to the manifest, ignoring duplicates'''
        groups = self.sc.setdefault('policy_groups', [])
        if g not in groups:
            groups.append(g)

    def _del_policy_group(self, g, name=None):
        '''Remove policy group g from the manifest.

        The 'policy_groups' key is dropped once it becomes empty so it is not
        emitted at all. Calling this on a manifest without policy groups is a
        no-op. name is unused and kept for backward compatibility.
        '''
        groups = self.sc.get('policy_groups')
        if groups is None:
            return
        if g in groups:
            groups.remove(g)
        if not groups:
            del self.sc['policy_groups']

    def _update_template(self, t, name=None):
        '''Set the manifest template to t (name is unused)'''
        self.sc['template'] = t

    def _update_policy_vendor(self, v):
        '''Set the manifest policy_vendor to v'''
        self.sc['policy_vendor'] = v

    def _update_policy_version(self, v):
        '''Set the manifest policy_version to float(v) (v may be a string)'''
        self.sc['policy_version'] = float(v)

    def _emit_yaml(self):
        '''Serialize self.sc as yaml suitable for an sc-filtergen manifest.

        Floats go through float_representer, which emits policy_version as a
        quoted string (e.g. '15.10') so the trailing zero survives; the quotes
        are stripped again so the manifest contains policy_version: 15.10.
        '''
        unquote = re.compile(r"(policy_version: )'(.*)'")
        dumped = yaml.dump(self.sc, default_flow_style=False, indent=4)
        return "\n".join(unquote.sub(r'\1\2', line)
                         for line in dumped.splitlines())

    def _gen_test_data(self, vendor, version):
        '''Generate some test data.

        Creates templates/<vendor>/<version>/ and
        policygroups/<vendor>/<version>/ under self.tmpdir, each holding
        'common_<kind>' and 'restricted_<kind>' files whose only content is
        their own file name followed by a newline.
        '''
        for kind in ('templates', 'policygroups'):
            version_dir = os.path.join(self.tmpdir, kind, vendor,
                                       str(version))
            os.makedirs(version_dir)
            for prefix in ('common', 'restricted'):
                fname = "%s_%s" % (prefix, kind)
                with open(os.path.join(version_dir, fname), 'w') as f:
                    f.write("%s\n" % fname)

    def _filtergen(self, keep=False):
        '''Run sc-filtergen on the current manifest and assert it succeeds.

        The manifest is written to <tmpdir>/manifest and the filter to
        <tmpdir>/<profile_name>; the filter file is removed afterwards unless
        keep is True.
        '''
        contents = self._emit_yaml()
        debug("\n" + contents)
        out_fn = os.path.join(self.tmpdir, self.profile_name)
        m = os.path.join(self.tmpdir, "manifest")
        with open(m, 'w') as f:
            f.write(contents)
        rc, out = cmd(['sc-filtergen',
                       '--policy-dir=%s' % self.tmpdir,
                       '--manifest=%s' % m,
                       '--output-file=%s' % out_fn,
                       ])
        self.assertEqual(rc, 0,
                         "sc-filtergen exited with error:\n%s\n%s\n[%d]" %
                         (contents, out, rc))

        # The output is a regular file (test_filtergen_output reads it
        # directly), so show its contents rather than globbing beneath it.
        if os.path.isfile(out_fn):
            debug(out_fn)
            with open(out_fn, 'r') as f:
                debug("\n%s" % f.read())

        if not keep and os.path.exists(out_fn):
            os.unlink(out_fn)

    def test_templates_parse(self):
        '''Test that templates parse'''
        debug("")
        for vendor_dir in sorted(glob.glob("%s/templates/*" % self.tmpdir)):
            vendor = os.path.basename(vendor_dir)
            for version_dir in sorted(glob.glob("%s/*" % vendor_dir)):
                version = os.path.basename(version_dir)
                self._update_policy_version(version)
                self._update_policy_vendor(vendor)
                for template_fn in sorted(glob.glob("%s/*" % version_dir)):
                    template = os.path.basename(template_fn)
                    self._update_template(template)
                    debug("%s/%s/%s" % (vendor, version, template))
                    self._filtergen()

    def test_filtergen_output(self):
        '''Test sc-filtergen output'''
        vendor = "ubuntu-test"
        version = 1000.1001
        self._update_policy_version(version)
        self._update_policy_vendor(vendor)

        self._gen_test_data(vendor, version)
        fn = os.path.join(self.tmpdir, self.profile_name)
        for i in 'templates', 'policygroups':
            for j in ['common', 'restricted']:
                # Use the generated template and policy group with prefix j;
                # each generated file contains only its own name, and the
                # test expects "<j>_<i>" to show up in the filter output.
                self.sc['template'] = "%s_templates" % (j)
                self.sc['policy_groups'] = ["%s_policygroups" % (j)]

                self._filtergen(keep=True)

                search = "%s_%s" % (j, i)
                debug("Checking '%s' for '%s'" % (fn, search))
                with open(fn, 'r') as f:
                    contents = f.read()
                debug("%s contains:\n%s" % (os.path.basename(fn), contents))
                os.unlink(fn)

                self.assertTrue(search in contents, "Could not find "
                                "'%s' in '%s'" % (search, contents))

    def test_policygroups_parse(self):
        '''Test that policygroups parse'''
        for vendor_dir in sorted(glob.glob("%s/policygroups/*" %
                                           self.tmpdir)):
            vendor = os.path.basename(vendor_dir)
            for version_dir in sorted(glob.glob("%s/*" % vendor_dir)):
                version = os.path.basename(version_dir)
                self._update_policy_version(version)
                self._update_policy_vendor(vendor)
                # The template list for this vendor/version does not change
                # while iterating over the policy groups.
                templates = sorted(glob.glob("%s/templates/%s/%s/*" %
                                             (self.tmpdir, vendor, version)))
                for group_fn in sorted(glob.glob("%s/*" % version_dir)):
                    group = os.path.basename(group_fn)
                    self._add_policy_group(group)
                    for template_fn in templates:
                        template = os.path.basename(template_fn)
                        self._update_template(template)
                        debug("%s/%s/%s (%s)" % (vendor, version, group,
                                                 template))
                        self._filtergen()
                    self._del_policy_group(group)

#
# Main
#
if __name__ == '__main__':
    parser = optparse.OptionParser()
    parser.add_option("-d", "--debug",
                      dest='debug',
                      help='emit debugging information',
                      action='store_true',
                      default=False)
    # optparse handles both -d and --debug; positional args are not used.
    (opt, _) = parser.parse_args()
    debugging = opt.debug

    # Policy data is read from <topdir>/data/seccomp, where topdir is the
    # parent of the directory containing this script.
    absfn = os.path.abspath(sys.argv[0])
    topdir = os.path.dirname(os.path.dirname(absfn))

    # run the tests
    suite = unittest.TestSuite()
    suite.addTest(unittest.TestLoader().loadTestsFromTestCase(T))
    result = unittest.TextTestRunner(verbosity=2).run(suite)

    if not result.wasSuccessful():
        sys.exit(1)
