#!/usr/bin/python
# Copyright (C) 2011 Denis Bilenko (http://denisbilenko.com)
# Homepage: http://github.com/denik/cython-ifdef
import sys
import os
import datetime
from cStringIO import StringIO
import difflib
import uuid

# "#if XXX" requires a different configuration to process than "#ifdef XXX":
# for "#ifdef XXX" it is enough to pass "-DXXX" and "-UXXX", while
# for "#if XXX" we need "-DXXX=1" and "-DXXX=0".


# Random placeholder that convert_comments() substitutes for the newlines
# inside multi-line comments; main() swaps it back to '\n' when writing
# the merged output.
newline_token = ' ' + uuid.uuid4().hex + ' '


class options:
    """Process-wide settings; parse_commandline() also sets sourcefile and cython_args."""
    verbose = False  # when true, produce_preprocessor() echoes emitted lines
    output = None    # path of the .c file cython writes (-o / --output-file)


def get_symbols(filename):
    """Return the sorted, de-duplicated words printed by ``unifdef -t -s``.

    Presumably these are the preprocessor symbols used in *filename*'s
    conditionals (see unifdef(1) ``-s``) -- TODO confirm.  Exits the
    program if unifdef reports a non-zero status.
    """
    command = "unifdef -t -s '%s'" % filename
    pipe = os.popen(command)
    words = pipe.read().split()  # split() with no argument also strips
    status = pipe.close()        # None means exit status 0
    if status is not None:
        sys.exit('%r failed with code %r' % (command, os.WEXITSTATUS(status)))
    return sorted(set(words))


def parse_commandline():
    """Parse sys.argv into the module-level ``options`` namespace.

    Expected form: ``script [cython-options] sourcefile`` where sourcefile
    must exist and end in .pyx or .py.  Sets options.sourcefile,
    options.output (taken from -o / --output-file, otherwise the source
    name with a .c extension in the same directory) and
    options.cython_args (the remaining arguments joined with spaces).
    Exits the program on invalid input.
    """
    args = sys.argv[1:]
    if not args:
        sys.exit('Usage: %s [cython-options] sourcefile.pyx' % sys.argv[0])
    sourcefile = args.pop()
    options.sourcefile = sourcefile
    if not os.path.exists(sourcefile):
        sys.exit('File not found: %s' % sourcefile)
    if not sourcefile.endswith(('.pyx', '.py')):
        sys.exit('Invalid extension: %s' % sourcefile)

    position = None
    for flag in ('-o', '--output-file'):
        if flag in args:
            position = args.index(flag)
            break
    if position is None:
        # No explicit output: write foo.c next to foo.pyx.
        directory, basename = os.path.split(sourcefile)
        options.output = os.path.join(directory, basename.rsplit('.', 1)[0] + '.c')

    if options.output is None:
        args.pop(position)  # drop the flag itself
        if position >= len(args):
            # the flag was the last argument, so it has no value
            sys.exit('Invalid command line: %s' % (sys.argv, ))
        options.output = args.pop(position)  # and consume its value
    options.cython_args = ' '.join(args)


def system(command):
    """Echo *command*, run it through the shell, and exit the program if it fails."""
    print(command)
    status = os.system(command)
    if status != 0:
        sys.exit('%r failed with code %s' % (command, os.WEXITSTATUS(status)))


def system_unifdef(command):
    """Echo and run a unifdef *command*; exit the program unless it returns 0 or 1.

    Both 0 and 1 are accepted as success (presumably "output unchanged"
    and "output changed" -- see unifdef(1)); any other exit code aborts.
    """
    print(command)
    code = os.WEXITSTATUS(os.system(command))
    if code in (0, 1):
        return
    sys.exit('%r failed with code %s' % (command, code))


def unlink(filename):
    """Remove *filename*, silently ignoring a file that does not exist.

    Any other OSError (permission denied, path is a directory, ...) is
    re-raised.
    """
    import errno
    try:
        os.unlink(filename)
    except OSError as ex:
        # Compare errno rather than the message text, which varies with
        # platform and locale.
        if ex.errno != errno.ENOENT:
            raise


def link_force(source, dest):
    """Hard-link *source* as *dest*, first removing any existing *dest*.

    Returns whatever os.link returns.
    """
    unlink(dest)
    result = os.link(source, dest)
    return result


class Config(object):
    """Per-configuration file names derived from a key and options.sourcefile.

    For source ``dir/foo.pyx`` and key ``'-DA -UB'`` the names are
    ``foo_DA__UB.pyx`` and ``foo_DA__UB.c`` (spaces and dashes in the key
    become underscores).
    """

    def __init__(self, key):
        self.key = key
        stem = os.path.basename(options.sourcefile).rsplit('.', 1)[0]
        suffix = key.replace(' ', '_').replace('-', '_')
        self.pyx_name = stem + suffix + '.pyx'
        self.c_name = stem + suffix + '.c'

    def __str__(self):
        return self.key


def convert_comments(filename, today):
    """Rewrite a Cython-generated C file in place so that it diffs cleanly.

    - If the first line is a ``/* Generated by Cython ... */`` header, its
      trailing `` on <date>`` part is replaced by
      `` + cython_ifdef.py on <today>``; any other first line is copied.
    - Every later newline-terminated line has its trailing whitespace
      stripped.
    - A comment that opens with ``/* `` as the first non-blank text of a
      line and is not closed on that line is joined onto one line: the
      newlines up to the line containing ``*/`` are replaced with
      newline_token (main() turns them back into newlines after merging).
      Presumably this lets difflib treat each comment as a single line --
      TODO confirm.

    The result is written to ``filename + '.temp'``, both files are
    closed, and the temp file is then renamed over *filename*.
    """
    temp_name = filename + '.temp'
    with open(filename) as source:
        with open(temp_name, 'w') as output:
            firstline = source.readline()
            header = firstline.strip()
            if header.lower().startswith('/* generated by cython ') and header.endswith('*/'):
                # Keep "Generated by Cython X.Y" and drop its " on <date>".
                line = header.strip('/*').strip().split(' on ')[0]
                output.write('/* ' + line + ' + cython_ifdef.py on %s */\n' % today)
            else:
                output.write(firstline)

            in_comment = False
            for line in source:
                if line.endswith('\n'):
                    line = line[:-1].rstrip() + '\n'

                if in_comment:
                    if '*/' in line:
                        # closing line keeps its real newline
                        in_comment = False
                        output.write(line)
                    else:
                        output.write(line.replace('\n', newline_token))
                elif line.lstrip().startswith('/* ') and '*/' not in line:
                    line = line.lstrip()  # cython adds space before /* for some reason
                    output.write(line.replace('\n', newline_token))
                    in_comment = True
                else:
                    output.write(line)
    # Rename only after both handles are closed (required on Windows).
    os.rename(temp_name, filename)


def compact_tag_set(tags):
    """Remove, in place, every '-DX'/'-UX' pair that cancels out in *tags*.

    Tags without a contradicting partner are kept.  Raises ValueError for
    any entry that does not start with '-D' or '-U'.  Returns None.
    """
    opposite = {'-D': '-U', '-U': '-D'}
    for entry in set(tags):  # iterate a snapshot; *tags* shrinks below
        flag, name = entry[:2], entry[2:]
        if flag not in opposite:
            raise ValueError('Cannot process: %r' % (entry, ))
        counterpart = opposite[flag] + name
        if counterpart in tags:
            tags.discard(entry)
            tags.discard(counterpart)


class Str(str):
    """A str subclass that carries a set of configuration tags.

    ``tag`` is a set of '-DX'/'-UX' strings.  Concatenation in either
    direction unions the tags of both operands, and the str methods listed
    in ``methods`` return a Str that keeps ``self.tag``, so tags survive
    the string operations performed in unified_diff() and _merge().
    """

    def __new__(cls, string, tag=None):
        # Without an explicit tag, inherit the .tag of *string* (if it has one).
        if tag is None:
            tag = getattr(string, 'tag', set())
        self = str.__new__(cls, string)
        self.string = string  # the object this Str was built from
        # Always store a fresh set so instances never share one tag set.
        self.tag = set(tag)
        return self

    def __repr__(self):
        return '%s(%s, %r)' % (self.__class__.__name__, str.__repr__(self), self.tag)

    def __add__(self, other):
        # self + other: union of both tag sets (other may be a plain str).
        newtag = self.tag | getattr(other, 'tag', set())
        return self.__class__(str.__add__(self, other), newtag)

    def __radd__(self, other):
        # other + self, e.g. '-' + line in unified_diff(); same tag union.
        newtag = self.tag | getattr(other, 'tag', set())
        return self.__class__(str.__add__(other, self), newtag)

    # str methods whose results should stay Str instances carrying
    # self.tag.  In Python 2, simple slices such as line[1:] go through
    # __getslice__, so slicing preserves the tag too.
    methods = ['__getslice__', '__getitem__', '__mul__', '__rmod__', '__rmul__',
               'join', 'replace', 'upper', 'lower']

    # Generate one wrapper per name with the Python 2 exec statement; each
    # calls the plain str method and re-wraps the result with self.tag.
    # Note: `methods` and `method` remain behind as class attributes.
    for method in methods:
        exec '''def %s(self, *args):
    return self.__class__(str.%s(self, *args), self.tag)''' % (method, method)


def unified_diff(a, b, fromfile='', tofile='', fromfiledate='',
                 tofiledate='', n=1000000, lineterm='\n'):
    """Yield a full-context diff of line lists *a* and *b*, keeping tags.

    Loosely modelled on difflib.unified_diff, but with no headers or hunk
    markers.  Each yielded item is one line prefixed with ' ' (in both),
    '-' (only in *a*) or '+' (only in *b*).  For matched lines the result
    is a Str whose tag is the union of the tags of both input lines.  For
    '-'/'+' lines that are Str instances, the prefix concatenation keeps
    the line's own tag (via Str.__radd__).

    *n* is the context size passed to get_grouped_opcodes(); the large
    default keeps everything in one group, so every line is emitted.
    *fromfile*, *tofile*, *fromfiledate*, *tofiledate* and *lineterm* are
    accepted only for signature compatibility with difflib.unified_diff
    and are ignored.
    """
    matcher = difflib.SequenceMatcher(None, a, b)
    for group in matcher.get_grouped_opcodes(n):
        for opcode, i1, i2, j1, j2 in group:
            if opcode == 'equal':
                assert i2 - i1 == j2 - j1, locals()
                for line_a, line_b in zip(a[i1:i2], b[j1:j2]):
                    merged_tag = getattr(line_a, 'tag', set()) | getattr(line_b, 'tag', set())
                    yield ' ' + Str(line_a, merged_tag)
                continue
            if opcode in ('replace', 'delete'):
                for line in a[i1:i2]:
                    yield '-' + line
            if opcode in ('replace', 'insert'):
                for line in b[j1:j2]:
                    yield '+' + line


def _merge(lines1, lines2, tag1, tag2, tag3):
    """Diff two configurations' lines and yield every line as a tagged Str.

    Lines found only in *lines1* get *tag1* added, lines only in *lines2*
    get *tag2*, and lines common to both get *tag3*.  Each resulting tag
    set is then compacted with compact_tag_set().
    """
    extra = {' ': set(tag3), '+': set(tag1), '-': set(tag2)}
    # lines2 is the "a" side, so '-' marks lines2-only and '+' lines1-only.
    for diff_line in unified_diff(lines2, lines1, n=100000):
        marker, body = diff_line[0], diff_line[1:]
        merged = Str(body)
        merged.tag |= extra[marker]
        compact_tag_set(merged.tag)
        yield merged


class Source(object):
    """The lines generated for one configuration, plus that configuration.

    *text* is either the whole file as a str (split into lines that keep
    their endings) or an already-split list of lines (stored as-is).
    *config* is anything _tags() accepts; it is normalized into ``key``
    (canonical string), ``config`` (set of -D/-U flags) and ``symbols``
    (set of bare symbol names).
    """

    def __init__(self, text, config):
        if isinstance(text, list):
            self.lines = text
        elif isinstance(text, str):
            self.lines = StringIO(text).readlines()
        else:
            raise TypeError('Invalid type: %r' % (text, ))
        self.key = _tags(config)
        flags = set(self.key.split())
        self.config = flags
        self.symbols = set(flag[2:] for flag in flags)

    def __repr__(self):
        line_count = len(self.lines)
        return 'Source(%s lines, %r)' % (line_count, self.key)


def sortkey(option):
    """Sort key for '-DX'/'-UX' flags: order by symbol name, then by prefix."""
    return (option[2:], option[:2])


def _tags(config):
    if isinstance(config, str):
        config = config.split()
    config = set(config)
    for x in config:
        if x.startswith('-D') or x.startswith('-U') and len(x) > 2:
            pass
        else:
            raise ValueError('Bad entry %r in config %r' % (x, config))
    return ' '.join(sorted(config, key=sortkey))


def pairs(iterable):
    """Yield consecutive non-overlapping pairs: [a, b, c, d] -> (a, b), (c, d).

    Raises AssertionError if *iterable* has an odd number of items.
    """
    iterator = iter(iterable)
    # The builtin next() works on any iterator (Python 2.6+ and 3),
    # unlike the Python 2-only iterator.next() method.
    for first in iterator:
        try:
            second = next(iterator)
        except StopIteration:
            raise AssertionError('Invalid argument for pairs: %s' % (iterable, ))
        yield (first, second)


def _bin(number, length):
    result = bin(number)[2:]
    return '0' * (length - len(result)) + result


def iter_configurations(symbols):
    """Yield the normalized key of every -D/-U assignment of *symbols*.

    Configurations are produced in binary counting order: number i
    defines ('-D') exactly the symbols whose bit is set in i, with the
    most significant bit matching the first symbol in iteration order.
    """
    count = len(symbols)
    for number in xrange(2 ** count):
        bits = _bin(number, count)
        flags = []
        for bit, symbol in zip(bits, symbols):
            if bit == '1':
                flags.append('-D' + symbol)
            else:
                flags.append('-U' + symbol)
        yield _tags(flags)


def get_configurations(symbols):
    """Return iter_configurations(symbols) as a list."""
    configurations = []
    configurations.extend(iter_configurations(symbols))
    return configurations


def merge(sources):
    r"""Merge per-configuration sources into one list of (line, key) pairs.

    *sources* holds one Source per combination of the symbols involved.
    Each round pairs up configurations that differ only in the last symbol
    of get_configurations()' iteration order.  It diff-merges each pair
    into a Source keyed by the flags both share, and repeats until a
    single Source remains.  That Source's lines are returned as
    ``(str(line), _tags(line.tag))`` tuples.

    Raises KeyError if a required configuration is missing from *sources*,
    and ValueError if a round yields no merged sources.

    >>> src1 = Source('hello\nworld\n', '-Dhello -Dworld')
    >>> src2 = Source('goodbye\nworld\n', '-Uhello -Dworld')
    >>> src3 = Source('hello\neveryone\n', '-Dhello -Uworld')
    >>> src4 = Source('goodbye\neveryone\n', '-Uhello -Uworld')
    >>> from pprint import pprint
    >>> pprint(merge([src1, src2, src3, src4]))
    [('hello\n', '-Dhello'),
     ('goodbye\n', '-Uhello'),
     ('world\n', '-Dworld'),
     ('everyone\n', '-Uworld')]
    """
    while True:
        symbols = set()
        by_key = {}
        for source in sources:
            symbols.update(source.symbols)
            by_key[source.key] = source
        merged = []
        for first_key, second_key in pairs(get_configurations(symbols)):
            first, second = by_key[first_key], by_key[second_key]
            shared = first.config & second.config
            lines = list(_merge(first.lines, second.lines, first.config, second.config, shared))
            merged.append(Source(lines, shared))
        if not merged:
            raise ValueError("Something went wrong")
        if len(merged) == 1:
            return [(str(line), _tags(line.tag)) for line in merged[0].lines]
        sources = merged


def convert_key_to_ifdef(key):
    """Translate a key such as '-DA -UB' into a C preprocessor condition.

    A single flag becomes '#ifdef X' / '#ifndef X'; several flags become
    '#if defined (A) && !defined (B)'.  Raises ValueError if any flag is
    not '-D' or '-U' prefixed.
    """
    flags = key.split()
    if len(flags) == 1:
        prefix, name = flags[0][:2], flags[0][2:]
        if prefix == '-D':
            return '#ifdef %s' % name
        if prefix == '-U':
            return '#ifndef %s' % name
    conditions = []
    for flag in flags:
        prefix, name = flag[:2], flag[2:]
        if prefix == '-D':
            conditions.append('defined (%s)' % name)
        elif prefix == '-U':
            conditions.append('!defined (%s)' % name)
        else:
            raise ValueError(repr(flags))
    return '#if ' + ' && '.join(conditions)


def exact_reverse(a, b):
    """Return True if keys *a* and *b* are single flags '-DX' and '-UX' (either order).

    Otherwise returns None, including when either key is empty/None or
    holds more than one flag.
    """
    if not (a and b):
        return None
    left, right = a.split(), b.split()
    if len(left) != 1 or len(right) != 1:
        return None
    first, second = left[0], right[0]
    if first[2:] == second[2:] and set([first[:2], second[:2]]) == set(['-D', '-U']):
        return True
    return None


def produce_preprocessor(iterable):
    """Turn (line, key) pairs into C text guarded by #if/#else/#endif.

    *iterable* yields ``(line, key)`` tuples such as those returned by
    merge(); *key* is a flag string, or ''/None for unconditional lines.
    Consecutive lines sharing a key go into one conditional block.  When a
    key is the exact reverse of the open one (e.g. '-UX' after '-DX'), an
    ``#else`` is emitted instead of closing and reopening the block.  This
    happens at most once per block, because a second ``#else`` inside the
    same ``#if`` is invalid.

    When options.verbose is set, every directive and every conditional
    content line is echoed to stdout with a running output line number.
    """
    current_line = [0]  # list so the nested function can update it (Py2)

    def wrap(line, log=True):
        current_line[0] += 1
        if options.verbose and log:
            sys.stdout.write('%5d: %s' % (current_line[0], line))
        return line

    state = None     # key of the currently open conditional, or None
    in_else = False  # True once the open conditional has had its #else
    for line, key in iterable:
        key = key or None  # treat '' like None: unconditional
        if key != state:
            if not in_else and exact_reverse(key, state):
                yield wrap('#else\n')
                in_else = True
            else:
                if state:
                    yield wrap('#endif /* %s */\n' % state)
                if key:
                    yield wrap(convert_key_to_ifdef(key) + '\n')
                in_else = False
            state = key
        # content lines are echoed only when they are conditional
        yield wrap(line, key)
    if state:
        yield wrap('#endif /* %s */\n' % state)


def main():
    """Run cython once per #ifdef configuration and merge the results.

    With no conditional symbols in the source, a single cython run is
    done and the program exits.  Otherwise, for every configuration from
    iter_configurations(), the source is unifdef-ed, compiled with cython,
    and post-processed by convert_comments().  All outputs are then merged
    into options.output as one file with preprocessor guards.
    """
    parse_commandline()
    symbols = get_symbols(options.sourcefile)

    if not symbols:
        system('cython %s -o %s %s' % (options.cython_args, options.output, options.sourcefile))
        # TODO: do exec
        sys.exit(0)

    print('%s: found symbols: %s' % (options.sourcefile, ', '.join(symbols)))
    today = str(datetime.date.today())
    sources = []

    # Move the original aside and write each unifdef-ed variant under the
    # original name for cython to compile; the finally clause always
    # moves the original back.
    tmpname = options.sourcefile + '.saved.%s' % os.getpid()
    os.rename(options.sourcefile, tmpname)

    try:
        for key in iter_configurations(symbols):
            system_unifdef('unifdef -t -b %s %s > %s' % (key, tmpname, options.sourcefile))
            system('cython %s -o %s %s' % (options.cython_args, options.output, options.sourcefile))
            convert_comments(options.output, today)
            with open(options.output) as generated:
                sources.append(Source(generated.read(), key))
    finally:
        os.rename(tmpname, options.sourcefile)

    sys.stderr.write('Merging (might take a while)\n')
    # Close (and so flush) the output deterministically instead of relying
    # on garbage collection of an anonymous file object.
    with open(options.output, 'w') as output:
        for line in produce_preprocessor(merge(sources)):
            output.write(line.replace(newline_token, '\n'))


if __name__ == "__main__":
    # Script entry point; importing this module runs nothing.
    main()