#!/usr/bin/env python3
#
# apt-metalink - Download deb packages from multiple servers concurrently
# Copyright (C) 2010 Tatsuhiro Tsujikawa
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import os.path
import subprocess
import textwrap
import sys
import optparse
import errno
import hashlib

import apt
import apt_pkg

class AptMetalink:
    """Resolve apt package changes and fetch the needed .deb files via aria2c.

    Packages are marked through the python-apt cache, the archives that
    are not yet present locally are described in a Metalink document
    that is piped to aria2c, and the changes are finally committed
    through the apt cache.
    """

    def __init__(self, opts):
        """Open the apt cache and look up the archive directory.

        opts is the parsed option object; this class reads its aria2c,
        hash_check, metalink_out and download_only attributes.

        Raises Exception if Dir::Cache::Archives is not configured.
        """
        self.cache = apt.Cache(apt.progress.text.OpProgress())
        self.opts = opts
        self.archive_dir = apt_pkg.config.find_dir('Dir::Cache::Archives')
        if not self.archive_dir:
            raise Exception(('No archive dir is set.'
                             ' Usually it is /var/cache/apt/archives/'))

    def upgrade(self, dist_upgrade=False):
        """Mark upgrades in the cache and process the resulting changes.

        dist_upgrade is passed straight through to the cache's upgrade().
        """
        self.cache.upgrade(dist_upgrade=dist_upgrade)
        self._get_changes()

    def install(self, pkg_names):
        """Mark the named packages and process the resulting changes.

        A package that is not installed is marked for installation; an
        installed package is marked for upgrade when it is upgradable and
        is otherwise left alone.

        Raises Exception if a name is not present in the cache.
        """
        for pkg_name in pkg_names:
            if pkg_name in self.cache:
                pkg = self.cache[pkg_name]
                if not pkg.installed:
                    pkg.mark_install()
                elif pkg.is_upgradable:
                    pkg.mark_upgrade()
            else:
                raise Exception('{0} is not found'.format(pkg_name))
        self._get_changes()

    @staticmethod
    def _is_affirmative(answer):
        """Return True if the typed answer means "yes".

        An empty answer selects the default (yes). Both "y" and "yes"
        are accepted, in any letter case; anything else means no.
        """
        return answer.strip().lower() in ('', 'y', 'yes')

    def _confirm(self):
        """Prompt on stdout and return True if the user wants to go on."""
        sys.stdout.write("Do you want to continue [Y/n]?")
        sys.stdout.flush()
        return self._is_affirmative(sys.stdin.readline())

    def _get_changes(self):
        """Summarize, confirm, download and commit the pending changes.

        With metalink_out set, the Metalink document is written to that
        file and nothing is downloaded or committed. With download_only
        set, the commit step is skipped.
        """
        pkgs = sorted(self.cache.get_changes(), key=lambda p: p.name)
        if pkgs:
            _print_update_summary(self.cache, pkgs)
            if not self._confirm():
                print("Abort.")
                return
            # Only packages that will be installed/upgraded and whose
            # archive is not already in place need to be fetched.
            pkgs = [pkg for pkg in pkgs
                    if not pkg.marked_delete
                    and not self._file_downloaded(
                        pkg, hash_check=self.opts.hash_check)]
            if self.opts.metalink_out:
                with open(self.opts.metalink_out, 'w', encoding='utf-8') as f:
                    make_metalink(f, pkgs)
                return
            if not self._download(pkgs, num_concurrent=guess_concurrent(pkgs)):
                print("Some download fails. apt_pkg will take care of them.")
        if self.opts.download_only:
            print("Download complete and in download only mode")
        else:
            self.cache.commit(apt.progress.text.AcquireProgress())

    def _download(self, pkgs, num_concurrent=1):
        """Download the candidate archives of pkgs with aria2c.

        Files are downloaded into the "partial" subdirectory of the
        archive directory and moved into the archive directory once
        complete.

        Returns True if there was nothing to do, or if aria2c exited
        with status 0 and every finished file could be moved.
        """
        if not pkgs:
            return True
        partial_dir = os.path.join(self.archive_dir, 'partial')
        # aria2c only accepts a concurrency of 1 or more.
        num_concurrent = max(1, num_concurrent)
        cmdline = [self.opts.aria2c,
                   '--metalink-file=-',
                   '--file-allocation=none',
                   '--auto-file-renaming=false',
                   '--dir={0}'.format(partial_dir),
                   '--max-concurrent-downloads={0}'.format(num_concurrent),
                   '--no-conf',
                   '--remote-time=true',
                   '--auto-save-interval=0',
                   '--continue',
                   '--split=1'
                   ]
        if self.opts.hash_check:
            cmdline.append('--check-integrity=true')

        # Reuse apt's proxy settings; https falls back to the http proxy.
        http_proxy = apt_pkg.config.find('Acquire::http::Proxy')
        https_proxy = apt_pkg.config.find('Acquire::https::Proxy', http_proxy)
        ftp_proxy = apt_pkg.config.find('Acquire::ftp::Proxy')

        if http_proxy:
            cmdline.append('='.join(['--http-proxy', http_proxy]))
        if https_proxy:
            cmdline.append('='.join(['--https-proxy', https_proxy]))
        if ftp_proxy:
            cmdline.append('='.join(['--ftp-proxy', ftp_proxy]))

        proc = subprocess.Popen(cmdline,
                                stdin=subprocess.PIPE,
                                stdout=1,
                                stderr=2,
                                encoding='utf-8')
        # If aria2c exits before it has read the whole document, writing
        # or closing the pipe raises BrokenPipeError. That is not fatal
        # here: the exit status collected below reports the failure.
        try:
            make_metalink(proc.stdin, pkgs)
        except BrokenPipeError:
            pass
        finally:
            try:
                proc.stdin.close()
            except BrokenPipeError:
                pass
        proc.wait()
        move_success = True
        # Move archives/partial/*.deb to archives/
        for pkg in pkgs:
            filename = get_filename(pkg.candidate)
            dst = os.path.join(self.archive_dir, filename)
            src = os.path.join(partial_dir, filename)
            ctrl_file = ''.join([src, '.aria2'])
            # If control file exists, we assume download is not
            # complete.
            if os.path.exists(ctrl_file):
                continue
            try:
                os.rename(src, dst)
            except OSError as e:
                # A missing source file just means the download did not
                # happen; only report other errors.
                if e.errno != errno.ENOENT:
                    print("Failed to move archive file", e)
                move_success = False
        return proc.returncode == 0 and move_success

    def _file_downloaded(self, pkg, hash_check=False):
        """Return True if pkg's candidate archive is already in place.

        The file must exist in the archive directory with the size
        recorded for the candidate. With hash_check, its digest must
        match as well; a candidate without any known hash cannot be
        verified and is reported as not downloaded.
        """
        candidate = pkg.candidate
        path = os.path.join(self.archive_dir, get_filename(candidate))
        # A single stat() avoids the window between an existence check
        # and a separate size lookup.
        try:
            if os.stat(path).st_size != candidate.size:
                return False
        except (OSError, ValueError):
            return False
        if not hash_check:
            return True
        hash_type, hash_value = get_hash(candidate)
        if hash_type is None:
            return False
        try:
            return check_hash(path, hash_type, hash_value)
        except OSError as e:
            if e.errno != errno.ENOENT:
                print("Failed to check hash", e)
            return False

def check_hash(path, hash_type, hash_value):
    with open(path, 'rb') as f:
        digest = hashlib.file_digest(f, hash_type)
    return digest.hexdigest() == hash_value

def get_hash(version):
    """Pick the strongest digest recorded on a package version.

    The sha256, sha1 and md5 attributes of version are tried in that
    order and the first non-empty one wins.

    Returns an (algorithm name, digest) tuple, or (None, None) when the
    version carries none of them.
    """
    for algorithm in ('sha256', 'sha1', 'md5'):
        digest = getattr(version, algorithm)
        if digest:
            return (algorithm, digest)
    return (None, None)

def get_filename(version):
    """Build the archive file name "<name>_<version>_<arch>.deb".

    Any ':' in the version string (the epoch separator) is written as
    "%3a".
    """
    # TODO: according to the apt-get man page, this file name and the
    # basename of the download URI may differ.
    return '{0}_{1}_{2}.deb'.format(
        version.package.shortname,
        version.version.replace(':', '%3a'),
        version.architecture)

def make_metalink(out, pkgs):
    """Write a Metalink XML document describing the candidates of pkgs.

    out is a writable text stream. For every package one <file> element
    is written with the archive file name, its size, its hash (when one
    is known) and one <url> element per download URI.

    File names and URIs are XML-escaped so that characters such as '&'
    or '<' cannot produce a malformed document.
    """
    # Imported here so the function does not depend on a module-level
    # import being present.
    from xml.sax.saxutils import escape

    out.write('<?xml version="1.0" encoding="UTF-8"?>')
    out.write('<metalink xmlns="urn:ietf:params:xml:ns:metalink">')
    for pkg in pkgs:
        version = pkg.candidate
        hashtype, hashvalue = get_hash(version)
        # The name lives in a double-quoted attribute, so '"' must be
        # escaped in addition to the default '&', '<' and '>'.
        name = escape(get_filename(version), {'"': '&quot;'})
        out.write('<file name="{0}">'.format(name))
        out.write('<size>{0}</size>'.format(version.size))
        if hashtype:
            out.write('<hash type="{0}">{1}</hash>'.format(hashtype, hashvalue))
        for uri in version.uris:
            out.write('<url priority="1">{0}</url>'.format(escape(uri)))
        out.write('</file>')
    out.write('</metalink>')

def guess_concurrent(pkgs):
    """Return how many downloads to run concurrently for pkgs.

    The value is the largest number of URIs offered by any package's
    candidate version, but never less than 1: a concurrency of 0 (empty
    pkgs, or candidates without URIs) is not a usable setting for the
    downloader.
    """
    most_uris = max((len(pkg.candidate.uris) for pkg in pkgs), default=0)
    return max(1, most_uris)

def pprint_names(msg, names):
    if names:
        print(msg)
        print(textwrap.fill(' '.join(names),
                            width=78,
                            initial_indent='  ',
                            subsequent_indent='  ',
                            break_long_words=False,
                            break_on_hyphens=False))

def unit_str(val):
    """Format a byte count using decimal (1000-based) units.

    Values below 1000 are shown in B, values that round to fewer than
    1000 kB are shown in whole kB, and everything else is shown in MB
    with one decimal. Choosing the unit after rounding keeps the
    boundaries clean: 1000 is "1kB" and 1000000 is "1.0MB" rather than
    "1000B" and "1000kB".
    """
    if val < 1000:
        return '{0:.0f}B'.format(val)
    kilobytes = round(val / 1000)
    if kilobytes < 1000:
        return '{0:.0f}kB'.format(kilobytes)
    return '{0:.1f}MB'.format(val / 1000 / 1000)

def _print_update_summary(cache, pkgs):
    """Print a summary of the pending changes in pkgs.

    Lists the packages to be removed, newly installed and upgraded,
    then the per-category counts (plus cache.keep_count as "not
    upgraded"), the download size and the change in disk usage read
    from cache.
    """
    removed, installed, upgraded = [], [], []
    # TODO marked_downgrade, marked_keep, marked_reinstall
    for pkg in pkgs:
        if pkg.marked_delete:
            bucket = removed
        elif pkg.marked_install:
            bucket = installed
        elif pkg.marked_upgrade:
            bucket = upgraded
        else:
            continue
        bucket.append(pkg.name)

    pprint_names('The following packages will be REMOVED:', removed)
    pprint_names('The following NEW packages will be installed:', installed)
    pprint_names('The following packages will be upgraded:', upgraded)

    counts_line = ('{0} upgraded, {1} newly installed, {2} to remove and'
                   ' {3} not upgraded')
    print(counts_line.format(len(upgraded), len(installed), len(removed),
                             cache.keep_count))

    download_size = unit_str(cache.required_download)
    print('Need to get {0} of archives.'.format(download_size))

    space = cache.required_space
    if space < 0:
        # A negative requirement means the operation frees disk space.
        space_line = 'After this operation, {0} disk space will be freed.'
        print(space_line.format(unit_str(-space)))
    else:
        space_line = ('After this operation, {0} of additional disk space will'
                      ' be used.')
        print(space_line.format(unit_str(space)))

def main():
    """Parse the command line and run the requested command.

    Supported commands are "upgrade", "dist-upgrade" and
    "install pkg ...". The process exits with status 1 when no command
    is given or when the command is not supported; the command is
    validated before the apt cache is opened.
    """
    usage = 'Usage: %prog [options] {upgrade | dist-upgrade | install pkg ...}'
    parser = optparse.OptionParser(usage=usage)
    parser.add_option('-d', '--download-only', action='store_true',
                      help="Download only. [default: %default]")
    parser.add_option('--metalink-out', metavar="FILE",
                      help=("""\
Instead of fetching the files, Metalink XML document is saved to given
FILE. Metalink XML document contains package's URIs and checksums.
"""))
    parser.add_option('--hash-check', action="store_true",
                      help=("Check hash of already downloaded files."
                            " If hash check fails, download file again."))
    parser.add_option('-x', '--aria2c' ,dest='aria2c',
                      help="path to aria2c executable [default: %default]")

    parser.set_defaults(download_only=False)
    parser.set_defaults(hash_check=False)
    parser.set_defaults(aria2c='/usr/bin/aria2c')
    opts, args = parser.parse_args()

    if not args:
        print('No command is given.')
        parser.print_usage()
        # sys.exit() rather than the interactive-only exit() helper.
        sys.exit(1)

    command = args[0]
    # Reject unknown commands up front, with a failing exit status,
    # instead of opening the apt cache first.
    if command not in ('upgrade', 'dist-upgrade', 'install'):
        print("Command {0} is not supported.".format(command))
        sys.exit(1)

    am = AptMetalink(opts)
    if command == 'upgrade':
        am.upgrade()
    elif command == 'dist-upgrade':
        am.upgrade(dist_upgrade=True)
    else:
        am.install(args[1:])

# Run the command-line interface only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()