#!/usr/bin/env python
#
# Copyright (C) 2010 Oracle. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU General Public License as published by the Free Software
# Foundation, version 2.  This program is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General
# Public License for more details.  You should have received a copy of the GNU
# General Public License along with this program; If not, see <http://www.gnu.org/licenses/>.

import sys
import os
import stat
import time
import string
import random
import tempfile
import commands
import subprocess
import urlgrabber
from optparse import OptionParser


# Candidate (kernel, ramdisk) file paths, relative to the install location.
# main() walks this list in order when --kernel / --ramdisk are not given and
# uses the first kernel path and the first ramdisk path that can be fetched
# (the two searches are independent of each other).
XEN_PATHS = [
    ('images/xen/vmlinuz', 'images/xen/initrd.img'), # Fedora <= 10 and OL = 5
    ('boot/i386/vmlinuz-xen', 'boot/i386/initrd-xen'), # openSUSE >= 10.2 and SLES >= 10
    ('boot/x86_64/vmlinuz-xen', 'boot/x86_64/initrd-xen'), # openSUSE >= 10.2 and SLES >= 10
    ('current/images/netboot/xen/vmlinuz', 'current/images/netboot/xen/initrd.gz'), # Debian
    ('images/pxeboot/vmlinuz', 'images/pxeboot/initrd.img'), # Fedora >=10 and OL >= 6
    ('isolinux/vmlinuz', 'isolinux/initrd.img'), # Fedora >= 10 and OL >= 6
]


def format_sxp(kernel, ramdisk, args):
    """Render the boot description in 'sxp' form.

    The result always starts with 'linux (kernel <kernel>)'; a
    '(ramdisk ...)' and a quoted '(args "...")' element are appended only
    when the corresponding value is truthy.
    """
    elements = ['linux (kernel %s)' % kernel]
    if ramdisk:
        elements.append('(ramdisk %s)' % ramdisk)
    if args:
        elements.append('(args "%s")' % args)
    return ''.join(elements)


def format_simple(kernel, ramdisk, args, sep):
    """Render the boot description as 'key value' records.

    Each record ('kernel', then optionally 'ramdisk' and 'args') is followed
    by sep, and one extra sep terminates the whole description.  Records for
    falsy ramdisk/args values are left out.
    """
    records = ['kernel %s' % kernel]
    if ramdisk:
        records.append('ramdisk %s' % ramdisk)
    if args:
        records.append('args %s' % args)
    # Every record carries its own separator; the trailing one ends the list.
    return ''.join([record + sep for record in records]) + sep


def mount(dev, path, option=''):
    """Mount dev on path using the platform's mount command.

    dev    -- what to mount (device, image file or NFS 'host:/path').
    path   -- the mount point.
    option -- extra command line options as one whitespace separated string,
              e.g. '-o ro,loop'.

    Raises RuntimeError when the command cannot be started or exits with a
    non-zero status; the message carries the command, its exit status and
    its combined stdout/stderr output.
    """
    if os.uname()[0] == 'SunOS':
        mountcmd = '/usr/sbin/mount'
    else:
        mountcmd = '/bin/mount'
    # Run the command from an argument vector rather than a shell string, so
    # that spaces or shell metacharacters in dev/path are never interpreted.
    cmd = [mountcmd] + option.split() + [dev, path]
    try:
        proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                stderr=subprocess.STDOUT,
                                universal_newlines=True)
        output = proc.communicate()[0]
    except OSError:
        # Typically a missing mount binary; report it like any other failure.
        err = sys.exc_info()[1]
        raise RuntimeError('Command: (%s) failed: %s' % (' '.join(cmd), err))
    if proc.returncode != 0:
        raise RuntimeError('Command: (%s) failed: (%s) %s'
                           % (' '.join(cmd), proc.returncode, output.rstrip('\n')))


def umount(path):
    """Unmount path with the platform's umount command.

    The command's exit status is not checked.
    """
    umount_bin = '/bin/umount'
    if os.uname()[0] == 'SunOS':
        umount_bin = '/usr/sbin/umount'
    subprocess.call([umount_bin, path])


class Fetcher:
    """Fetch boot files from a location that urlgrabber can read directly.

    location -- base URL or directory the files are fetched from.
    tmpdir   -- directory that receives the local copies.

    srcdir is the base that get_file() joins file names onto; here it is the
    location itself.
    """

    def __init__(self, location, tmpdir):
        self.location = location
        self.tmpdir = tmpdir
        self.srcdir = location

    def prepare(self):
        """Create the output directory (mode 0750) if it does not exist."""
        if not os.path.exists(self.tmpdir):
            # rwxr-x---, i.e. octal 750.
            os.makedirs(self.tmpdir,
                        stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP)

    def cleanup(self):
        """Nothing to undo for a plain fetcher."""
        pass

    def get_file(self, filename):
        """Copy srcdir/filename into tmpdir and return the local file name.

        The local copy is named 'xenpvboot.<basename>.<random suffix>'.
        Raises RuntimeError if the file cannot be fetched; no local file is
        left behind in that case.
        """
        url = os.path.join(self.srcdir, filename)
        local_name = None
        try:
            # mkstemp reserves a unique name atomically, so neither a name
            # collision nor a pre-existing file in tmpdir can be hit.
            (fd, local_name) = tempfile.mkstemp(
                prefix='xenpvboot.%s.' % os.path.basename(filename),
                dir=self.tmpdir)
            os.close(fd)
            return urlgrabber.urlgrab(url, local_name, copy_local=1)
        except Exception:
            err = sys.exc_info()[1]
            # Callers probe many candidate paths; do not litter tmpdir with
            # the placeholder of every failed attempt.
            if local_name is not None:
                try:
                    os.unlink(local_name)
                except OSError:
                    pass
            raise RuntimeError('Cannot get file %s: %s' % (url, err))


class MountedFetcher(Fetcher):
    """Fetcher for locations that have to be mounted before use.

    Handles 'nfs:host:/path' exports, block devices and image files (the
    latter are loop-mounted).  The source is mounted read-only on a fresh
    temporary directory below tmpdir, which becomes srcdir.
    """

    def prepare(self):
        """Create the mount point and mount the location read-only.

        Raises whatever Fetcher.prepare(), os.stat() or mount() raise.
        """
        Fetcher.prepare(self)
        self.srcdir = tempfile.mkdtemp(prefix='xenpvboot.', dir=self.tmpdir)
        if self.location.startswith('nfs:'):
            mount(self.location[len('nfs:'):], self.srcdir, '-o ro')
        else:
            if stat.S_ISBLK(os.stat(self.location)[stat.ST_MODE]):
                option = '-o ro'
            else:
                # A regular file (ISO or filesystem image) needs a loop device.
                option = '-o ro,loop'
            if os.uname()[0] == 'SunOS':
                option += ' -F hsfs'
            mount(self.location, self.srcdir, option)

    def cleanup(self):
        """Unmount and remove the temporary mount point, if one was created."""
        # Until prepare() has created the mount point, srcdir still is the
        # user supplied location; unmounting that could detach a device or
        # directory this script never mounted.
        if self.srcdir == self.location:
            return
        umount(self.srcdir)
        try:
            os.rmdir(self.srcdir)
        except OSError:
            # Best effort: the directory may still be busy or already gone.
            pass


class NFSISOFetcher(MountedFetcher):
    """Fetcher for an image file that lives on an NFS export.

    The location has the form 'nfs+iso:host:/path/file.iso'.  The directory
    holding the image is NFS-mounted on nfsdir, then the image itself is
    loop-mounted on srcdir.
    """

    def __init__(self, location, tmpdir):
        # Mount point of the NFS export; None until prepare() creates it.
        self.nfsdir = None
        MountedFetcher.__init__(self, location, tmpdir)

    def prepare(self):
        """Create both mount points, mount the export, then the image."""
        Fetcher.prepare(self)
        self.nfsdir = tempfile.mkdtemp(prefix='xenpvboot.', dir=self.tmpdir)
        self.srcdir = tempfile.mkdtemp(prefix='xenpvboot.', dir=self.tmpdir)
        source = self.location[len('nfs+iso:'):]
        nfs = os.path.dirname(source)
        iso = os.path.basename(source)
        mount(nfs, self.nfsdir, '-o ro')
        option = '-o ro,loop'
        if os.uname()[0] == 'SunOS':
            option += ' -F hsfs'
        mount(os.path.join(self.nfsdir, iso), self.srcdir, option)

    def cleanup(self):
        """Undo prepare(): unmount the image first, then the NFS export.

        Safe to call after a prepare() that failed part-way: mount points
        that were never created are skipped.
        """
        # srcdir equals location until the image mount point exists.
        if self.srcdir != self.location:
            MountedFetcher.cleanup(self)
            # Pause before unmounting the export underneath the image;
            # presumably this gives the loop mount time to be released.
            time.sleep(1)
        if self.nfsdir is None:
            return
        umount(self.nfsdir)
        try:
            os.rmdir(self.nfsdir)
        except OSError:
            # Best effort: the directory may still be busy or already gone.
            pass


class TFTPFetcher(Fetcher):
    """Fetcher for 'tftp://host[:port]/path' locations, using /usr/bin/tftp."""

    def get_file(self, filename):
        """Download filename from the TFTP server into tmpdir.

        Returns the local file name, 'xenpvboot.<basename>.<random suffix>'.
        Raises RuntimeError if the tftp command cannot be run or exits with a
        non-zero status; no local file is left behind in that case.
        """
        target = self.location[len('tftp://'):]
        if '/' in target:
            (host, basedir) = target.split('/', 1)
        else:
            (host, basedir) = (target, '')
        # The tftp client takes host and port as separate arguments.
        host_args = host.replace(':', ' ').split()
        local_name = None
        cmd = None
        try:
            # mkstemp reserves a unique local name atomically.
            (fd, local_name) = tempfile.mkstemp(
                prefix='xenpvboot.%s.' % os.path.basename(filename),
                dir=self.tmpdir)
            os.close(fd)
            # Argument vector instead of a shell string: nothing in the
            # location or file name is ever interpreted by a shell.
            cmd = (['/usr/bin/tftp'] + host_args
                   + ['-c', 'get', os.path.join(basedir, filename), local_name])
            proc = subprocess.Popen(cmd, stdout=subprocess.PIPE,
                                    stderr=subprocess.STDOUT,
                                    universal_newlines=True)
            output = proc.communicate()[0]
            if proc.returncode != 0:
                raise RuntimeError('Command: (%s) failed: (%s) %s'
                                   % (' '.join(cmd), proc.returncode,
                                      output.rstrip('\n')))
        except (OSError, RuntimeError):
            err = sys.exc_info()[1]
            # Do not leave the placeholder of a failed download behind.
            if local_name is not None:
                try:
                    os.unlink(local_name)
                except OSError:
                    pass
            if isinstance(err, RuntimeError):
                raise
            raise RuntimeError('Command: (%s) failed: %s'
                               % (' '.join(cmd or ['/usr/bin/tftp']), err))
        return local_name


def main():
    usage = '''%prog [option]

Get boot images from the given location and prepare for Xen to use.

Supported locations:

 - http://host/path
 - https://host/path
 - ftp://host/path
 - file:///path
 - tftp://host/path
 - nfs:host:/path
 - /path
 - /path/file.iso
 - /path/filesystem.img
 - /dev/sda1
 - nfs+iso:host:/path/file.iso
 - nfs+iso:host:/path/filesystem.img'''
    version = '%prog version 0.1'
    parser = OptionParser(usage=usage, version=version)
    parser.add_option('', '--location',
                      help='The base url for kernel and ramdisk files.')
    parser.add_option('', '--kernel',
                      help='The kernel image file.')
    parser.add_option('', '--ramdisk',
                      help='The initial ramdisk file.')
    parser.add_option('', '--args',
                      help='Arguments pass to the kernel.')
    parser.add_option('', '--output',
                      help='Redirect output to this file instead of stdout.')
    parser.add_option('', '--output-directory', default='/var/run/libxl',
                      help='Output directory.')
    parser.add_option('', '--output-format', default='sxp',
                      help='Output format: sxp, simple or simple0.')
    parser.add_option('-q', '--quiet', action='store_true',
                      help='Be quiet.')
    (opts, args) = parser.parse_args()

    if not opts.location and not opts.kernel and not opts.ramdisk:
        if not opts.quiet:
            print >> sys.stderr, 'You should at least specify a location or kernel/ramdisk.'
            parser.print_help(sys.stderr)
        sys.exit(1)

    if not opts.output or opts.output == '-':
        fd = sys.stdout.fileno()
    else:
        fd = os.open(opts.output, os.O_WRONLY)

    if opts.location:
        location = opts.location
    else:
        location = ''
    if (location == ''
        or location.startswith('http://') or location.startswith('https://')
        or location.startswith('ftp://') or location.startswith('file://')
        or (os.path.exists(location) and os.path.isdir(location))):
        fetcher = Fetcher(location, opts.output_directory)
    elif location.startswith('nfs:') or (os.path.exists(location) and not os.path.isdir(location)):
        fetcher = MountedFetcher(location, opts.output_directory)
    elif location.startswith('nfs+iso:'):
        fetcher = NFSISOFetcher(location, opts.output_directory)
    elif location.startswith('tftp://'):
        fetcher = TFTPFetcher(location, opts.output_directory)
    else:
        if not opts.quiet:
            print >> sys.stderr, 'Unsupported location: %s' % location
        sys.exit(1)

    try:
        fetcher.prepare()
    except Exception, err:
        if not opts.quiet:
            print >> sys.stderr, str(err)
        fetcher.cleanup()
        sys.exit(1)

    try:
        kernel = None
        if opts.kernel:
            kernel = fetcher.get_file(opts.kernel)
        else:
            for (kernel_path, _) in XEN_PATHS:
                try:
                    kernel = fetcher.get_file(kernel_path)
                except Exception, err:
                    if not opts.quiet:
                        print >> sys.stderr, str(err)
                    continue
                break

        if not kernel:
            if not opts.quiet:
                print >> sys.stderr, 'Cannot get kernel from loacation: %s' % location
            sys.exit(1)

        ramdisk = None
        if opts.ramdisk:
            ramdisk = fetcher.get_file(opts.ramdisk)
        else:
            for (_, ramdisk_path) in XEN_PATHS:
                try:
                    ramdisk = fetcher.get_file(ramdisk_path)
                except Exception, err:
                    if not opts.quiet:
                        print >> sys.stderr, str(err)
                    continue
                break
    finally:
        fetcher.cleanup()

    if opts.output_format == 'sxp':
        output = format_sxp(kernel, ramdisk, opts.args)
    elif opts.output_format == 'simple':
        output = format_simple(kernel, ramdisk, opts.args, '\n')
    elif opts.output_format == 'simple0':
        output = format_simple(kernel, ramdisk, opts.args, '\0')
    else:
        print >> sys.stderr, 'Unknown output format: %s' % opts.output_format
        sys.exit(1)

    sys.stdout.flush()
    os.write(fd, output)


# Run main() only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()