#!/usr/bin/env python3

import datetime
import errno
import logging
import os
import os.path
import re
import shlex
import subprocess
import sys
import tempfile
import textwrap
import time
from argparse import ArgumentParser
from contextlib import ExitStack


ZPOOL_NAME = 'test'
ZPOOL_DIR = '/zpools'
DATA_DIR = '/home/nikratio/tmp/buckets'
USE_BCACHE = True
USE_S3BACKER_CACHE = False
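# Object size (and thus S3 request granularity) for the "sb" (small block)
# and "lb" (large block) backing devices. The "sb" device becomes the zpool's
# special vdev, so blocks up to ZFS_BLOCKSIZE_THRESHOLD_KB land there.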
S3_BLOCKSIZE_SMALL_KB = 4
S3_BLOCKSIZE_LARGE_KB = 256
ZFS_BLOCKSIZE_THRESHOLD_KB = 256
ZFS_RECORDSIZE_KB = 512
KEYFILE = "/home/nikratio/my_secret.key"
BUCKET_NAME = 'my-amazon-s3-bucket'
BUCKET_REGION = 'eu-west-2'
DSET_NAME = 'laptop'
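# Retention schedule, passed verbatim to expire_zfs_snapshots.py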
KEEP_BACKUPS = [ 1, 2, 3, 5, 7, 10, 14, 20, 31, 60, 90, 180, 360, 540, 720 ]
RSYNC_OPTS = [ "-aHAXi", "--delete-during", "--delete-excluded", "--partial",
               "--inplace", "--sparse" ]



log = logging.getLogger(None if __name__ == '__main__' else __name__)


def wait_for(pred, *args, timeout=5, desc='(unspecified)'):
    log.info('Waiting for %s...', desc)
    waited = 0
    while waited < timeout:
        if pred(*args):
            return
        time.sleep(1)
        waited += 1
    raise TimeoutError(f'Timed out waiting for {desc}')


def run(cmdline, wait=True, **kwargs):
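    '''Log and run *cmdline*. With wait=True, block until completion and
    raise CalledProcessError on failure; otherwise return the Popen object.'''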
    log.debug('$ %s', shlex.join(cmdline))
    if wait:
        return subprocess.run(cmdline, **kwargs, check=True)
    else:
        return subprocess.Popen(cmdline, **kwargs)


def find_unused_nbd():
    if not os.path.exists('/sys/block/nbd0'):
        raise RuntimeError("Can't find NBDs - is the nbd module loaded?")
    for devno in range(20):
        if not os.path.exists(f'/sys/block/nbd{devno}/pid'):
            return devno
    raise RuntimeError("Can't find any available NBDs")


def find_unused_loop():
    if not os.path.exists('/sys/block/loop0'):
        raise RuntimeError("Can't find loop devices - is the loop module loaded?")
    for devno in range(20):
        if not os.path.exists(f'/sys/block/loop{devno}/loop'):
            return devno
    raise RuntimeError("Can't find any available loop devices")


def is_mountpoint(path):
    return os.path.ismount(str(path))


def rsync(src, dst, options, exclude=()):
    assert src.endswith('/')
    assert dst.endswith('/')
    cmd = [ 'rsync-no24' ]
    cmd += RSYNC_OPTS
    for name in exclude:
        cmd.append('--exclude')
        cmd.append(name)
    if options.bwlimit:
        cmd.append('--bwlimit')
        cmd.append(options.bwlimit)
    cmd.append(src)
    cmd.append(dst)
    run(cmd)


def write_file(path, val):
    log.debug('$ echo "%s" > %s', val, path)
    with open(path, 'w') as fh:
        print(val, file=fh)


def read_value(path):
    with open(path, 'r') as fh:
        return fh.read().strip()


def stop_nbdkits(nbdkits):
    log.info('Waiting for NBDKit instances to terminate...')
    for proc in nbdkits.values():
        try:
            proc.wait(5)
        except subprocess.TimeoutExpired:
            log.info('Sending termination signal...')
            proc.terminate()
            proc.wait()


def start_nbdkits(exit_stack, tempdir, variant='s3', size='50G',
                  latency=None, bandwidth=None):
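    '''Start one nbdkit instance per block-size kind ("sb" and "lb"), each
    serving a Unix socket in *tempdir* with the backend selected by *variant*.
    Returns ({kind: Popen}, {kind: socket path}).'''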
    sockets = {}
    nbdkits = {}

    exit_stack.callback(stop_nbdkits, nbdkits)

    log.info('Bringing up NBDs...')
    for kind in ('sb', 'lb'): # "small blocks" and "large blocks"
        socket_name = f'{tempdir}/nbd_socket_{kind}'
        nbdkit_args_1 = [
            'nbdkit', '--unix',  socket_name, '--foreground',
            '--filter=exitlast', '--filter=stats',
            '--threads', str(12 if kind == 'sb' else 2),
            '--filter=retry'
        ]
        nbdkit_args_2 = [
            f'statsfile={DATA_DIR}/{ZPOOL_NAME}_{kind}_stats.txt',
            'statsappend=true', 'statsthreshold=100',
            'retries=100', 'retry-readonly=false', 'retry-delay=30', 'retry-exponential=no'
        ]

        if latency:
            nbdkit_args_1.append('--filter=delay')
            nbdkit_args_2.extend([f'delay-write={latency}', f'delay-read={latency}'])

        if bandwidth:
            nbdkit_args_1.append('--filter=rate')
            nbdkit_args_2.append(f'rate={bandwidth}')

        if variant == 's3b':
            nbdkit_args_1.append('s3backer')
            nbdkit_args_2.extend([
                f'size={size}', f'bucket={BUCKET_NAME}/{ZPOOL_NAME}_{kind}',
                f'region={BUCKET_REGION}', 'listBlocks=true' ])
            if USE_S3BACKER_CACHE:
                path = f'{DATA_DIR}/{ZPOOL_NAME}_{kind}-s3bcache.dat'
                nbdkit_args_2.extend([
                    'blockCacheNoVerify=true', 'blockCacheSize=5000',
                    'blockCacheWriteDelay=1000', 'blockCacheRecoverDirtyBlocks=true',
                    'blockCacheFileAdvise=true', f'blockCacheFile={path}',
                ])
            if kind == 'lb':
                nbdkit_args_2.append(f'blockSize={S3_BLOCKSIZE_LARGE_KB}K')
            else:
                assert kind == 'sb'
                nbdkit_args_2.append(f'blockSize={S3_BLOCKSIZE_SMALL_KB}K')

        elif variant == 's3':
            nbdkit_args_1.append('S3')
            nbdkit_args_2.extend([
                f'size={size}', f'bucket={BUCKET_NAME}', f'key={ZPOOL_NAME}_{kind}',
                f'endpoint-url=http://s3.{BUCKET_REGION}.amazonaws.com',
            ])
            if kind == 'lb':
                nbdkit_args_2.append(f'object-size={S3_BLOCKSIZE_LARGE_KB}K')
            else:
                assert kind == 'sb'
                nbdkit_args_2.append(f'object-size={S3_BLOCKSIZE_SMALL_KB}K')

        elif variant == 'memory':
            nbdkit_args_1.append('memory')
            nbdkit_args_2.append(size)

        elif variant == 'file':
            path = f'{DATA_DIR}/{ZPOOL_NAME}_{kind}.img'
            if not os.path.exists(path):
                raise RuntimeError(f'{path} does not exist.')
            nbdkit_args_1.append('file')
            nbdkit_args_2.append(f'file={path}')

        else:
            raise ValueError(f'Unknown plugin: {variant}')

        nbdkits[kind] = run(nbdkit_args_1 + nbdkit_args_2, wait=False, stderr=subprocess.STDOUT)
        sockets[kind] = socket_name

    return (nbdkits, sockets)


def disconnect_nbd(devname):
    run(['nbd-client', '-d', f'/dev/{devname}'])


def connect_nbds(exit_stack, sockets, bcache_register=False):
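    '''Connect each nbdkit socket to a free /dev/nbdX device via nbd-client,
    optionally registering the device with bcache. Returns a map from kind
    to NBD device name.'''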
    nbd_devs = {}

    for (kind, socket) in sockets.items():
        wait_for(os.path.exists, socket, desc=f'{socket} to come up')
        devno = find_unused_nbd()
        devname = f'nbd{devno}'

        # We want the kernel to wait for NBD responses for a long time
        run(['nbd-client', '-unix', socket, '--timeout', str(7*24*60*60),
             '/dev/'+devname])
        nbd_devs[kind] = devname

        exit_stack.callback(disconnect_nbd, devname)

        if bcache_register:
            try:
                write_file('/sys/fs/bcache/register', '/dev/' + devname)
            except OSError as exc:
                # EINVAL is raised if the device is already registered
                if exc.errno != errno.EINVAL:
                    raise

        # The kernel does not honor the NBD block size, so if we limit this
        # then we increase the number of non-aligned reads+writes. Therefore,
        # set it to the maximum instead of the object size.
        #if kind == 'lb':
        #    write_file(f'/sys/block/{devname}/queue/max_sectors_kb', str(S3_BLOCKSIZE_LARGE_KB))
        #else:
        #    write_file(f'/sys/block/{devname}/queue/max_sectors_kb', str(S3_BLOCKSIZE_SMALL_KB))
        write_file(f'/sys/block/{devname}/queue/max_sectors_kb', str(32*1024))

    return nbd_devs


def wait_for_cache(nbd_devs):
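    '''Poll each bcache device's dirty_data counter and return once all
    caches have (almost) no dirty data left to write back.'''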
    # This is another way to block until the cache is flushed, but will
    # prevent hibernation until the write completes (so not a good idea).
    # path = f'/sys/block/{nbd_dev}/bcache/cache_mode'
    # try:
    #     write_file(path, 'none')
    # except FileNotFoundError:
    #     print(f'{path} not found, continuing anyway')

    mul = { 'k': 1024, 'M': 1024*1024, 'b': 1, 'G': 1024*1024*1024 }
    last_dirty = {}
    dirty = {}
    while True:
        for nbd_dev in nbd_devs.values():
            path = f'/sys/block/{nbd_dev}/bcache/dirty_data'
            dirty[nbd_dev] = read_value(path)

        if dirty == last_dirty:
            time.sleep(10)
            continue

        last_dirty.update(dirty)
        any_dirty = False
        log.info('Dirty data: %s', ', '.join(f'{x}: {y}' for x,y in dirty.items()))
        for (dev, val) in dirty.items():
            hit = re.match(r'^([0-9.]+)([kMGb])', val)
            if not hit:
                log.warning('Unable to parse dirty data, assuming cache is flushed.')
                continue
            dirty_bytes = float(hit.group(1)) * mul[hit.group(2)]
            if dirty_bytes > 1024*1024:
                any_dirty = True

        if not any_dirty:
            break


def stop_bcache(nbd_dev):
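    '''Stop the bcache cache set and backing device sitting on top of
    *nbd_dev*, then follow dmesg until bcache reports the device stopped.'''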
    bcache_dev = os.path.basename(os.readlink(f'/sys/block/{nbd_dev}/bcache/dev'))
    dmesg = subprocess.Popen(['dmesg', '--follow-new'], stdout=subprocess.PIPE,
                             universal_newlines=True)
    with ExitStack() as exit_stack:
        exit_stack.callback(dmesg.terminate)

        # Stop cache set (waits until backing devices have been stopped)
        path =  f'/sys/block/{nbd_dev}/bcache/cache/stop'
        try:
            write_file(path, '1')
        except FileNotFoundError:
            log.warning('%s not found, continuing anyway', path)

        # Stop backing device
        path =  f'/sys/block/{nbd_dev}/bcache/stop'
        try:
            write_file(path, '1')
        except FileNotFoundError:
            log.warning('%s not found, continuing anyway', path)
            return

        # Wait for bcache to fully terminate
        log.info('Waiting for %s (%s) to shut down...', bcache_dev, nbd_dev)
        for line in dmesg.stdout:
            if line.strip().endswith(f'bcache: bcache_device_free() {bcache_dev} stopped'):
                break


def connect_caches(exit_stack, nbd_devs, bcache_register=False):
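    '''Attach each kind's bcache cache image to a free loop device,
    optionally registering it with bcache, and schedule bcache shutdown on
    exit. Returns a map from kind to loop device name.'''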
    loop_devs = {}
    for (kind, nbd_dev) in nbd_devs.items():
        loop_dev = f'loop{find_unused_loop()}'
        loop_devs[kind] = loop_dev
        path = f'{DATA_DIR}/{ZPOOL_NAME}_{kind}-bcache.img'
        if not os.path.exists(path):
            raise RuntimeError(f'{path} does not exist.')
        run(['losetup', '/dev/' + loop_dev, path])
        exit_stack.callback(run, ['losetup', '-d', '/dev/' + loop_dev])

        if bcache_register:
            write_file('/sys/fs/bcache/register', '/dev/' + loop_dev)
        exit_stack.callback(stop_bcache, nbd_dev)

    return loop_devs


def create_zpool(exit_stack, devs):
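    '''Create a fresh zpool with the "lb" device as the normal vdev and the
    "sb" device as a special vdev, then create the backup dataset.'''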
    for p in (f'{ZPOOL_DIR}/{ZPOOL_NAME}/{DSET_NAME}',
              f'{ZPOOL_DIR}/{ZPOOL_NAME}'):
        if os.path.exists(p):
            os.rmdir(p)

    cmdline  = ['zpool', 'create', '-f', '-R', ZPOOL_DIR ]
    for arg in ('ashift=12', 'autotrim=on', 'failmode=continue'):
        cmdline.append('-o')
        cmdline.append(arg)
    for arg in ('acltype=posixacl', 'relatime=on', 'xattr=sa', 'compression=zstd-19',
                'checksum=sha256', 'sync=disabled',
                'redundant_metadata=most', f'recordsize={ZFS_RECORDSIZE_KB*1024}',
                'encryption=on', 'keyformat=passphrase', f'keylocation=file://{KEYFILE}'):
        cmdline.append('-O')
        cmdline.append(arg)
    cmdline += [ ZPOOL_NAME, '/dev/'+ devs['lb'],
                 'special', '/dev/'+ devs['sb'] ]
    run(cmdline)
    exit_stack.callback(zpool_export, exit_stack)

    # Must be set separately, cf. https://github.com/openzfs/zfs/issues/13815
    run(['zfs', 'set',  f'special_small_blocks={ZFS_BLOCKSIZE_THRESHOLD_KB*1024}', ZPOOL_NAME])

    run(['zfs', 'create', f'{ZPOOL_NAME}/{DSET_NAME}'])


def zpool_export(exit_stack):
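    '''Sync and export the zpool, blocking sleep/shutdown during the export
    and retrying until it succeeds.'''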
    # Attempting to hibernate/suspend while `zpool export` is running
    # is really bad (the system gets stuck in a semi-suspended state until
    # all filesystems are synced). So we sync as much data as possible
    # beforehand (the pending `zpool sync` command will also prevent
    # hibernation, but it will fail much more quickly).
    try:
        run(['zpool', 'sync', ZPOOL_NAME])
    except subprocess.CalledProcessError:
        print('zpool sync failed, aborting without cleanup...', file=sys.stderr)
        exit_stack.pop_all()
        raise

    # When exporting, we want to make sure that everything is on the backing
    # device.
    write_file('/sys/module/zfs/parameters/zfs_nocacheflush', '0')
    write_file('/sys/module/zfs/parameters/zil_nocacheflush', '0')

    while True:
        try:
            run(['systemd-inhibit', '--what=sleep:shutdown:idle', '--who=remote-backup',
                 '--why=zpool export running', '--mode=block', '--',
                 'zpool', 'export', ZPOOL_NAME])
            break
        except subprocess.CalledProcessError:
            print('zpool export failed, retrying...', file=sys.stderr)
        time.sleep(15)


def parse_args(args):
    '''Parse command line'''

    parser = ArgumentParser(
        description=textwrap.dedent('''\
        Make backups to Amazon S3 using ZFS over NBD.
        '''))

    parser.add_argument('--create', action='store_true', default=False,
                        help='Wipe out all existing data and recreate the pool')
    parser.add_argument('--mount-only', action='store_true', default=False,
                        help='Do not run a backup, just mount the zpool.')
    parser.add_argument('--bwlimit', type=str, default=None,
                        help='rate limit backups to this value (passed on to rsync)')
    parser.add_argument('--backend', type=str, default='s3',
                        help='Backend to use (s3, s3b, file, memory)')
    
    options = parser.parse_args(args)

    return options


def increase_task_timeout(exit_stack, timeout=300):
    # Increase the time until the kernel prints "hung task" warnings. We know
    # that tasks may appear hung for a while if data needs to be flushed
    # over the network.
    path = '/proc/sys/kernel/hung_task_timeout_secs'
    exit_stack.callback(write_file, path, read_value(path))
    write_file(path, str(timeout))


def main(args=None):
    options = parse_args(sys.argv[1:] if args is None else args)
    logging.basicConfig(level=logging.DEBUG, format='%(message)s')

    if os.geteuid() != 0:
        log.error('This script must be run as root.')
        sys.exit(3)

    # Make sure that our own output comes before the output of commands
    # that we call.
    sys.stdout.reconfigure(line_buffering=True)
    sys.stderr.reconfigure(line_buffering=True)

    with ExitStack() as exit_stack:
        increase_task_timeout(exit_stack)

        if USE_BCACHE:
            mount_with_bcache(exit_stack, options)
        else:
            mount(exit_stack, options)

        mountpoint = f'{ZPOOL_DIR}/{ZPOOL_NAME}/{DSET_NAME}'
        assert is_mountpoint(mountpoint)

        if options.mount_only:
            print('Press enter to unmount and exit.')
            sys.stdin.readline()
        else:
            run_backup(mountpoint, options)


def mount_with_bcache(exit_stack, options):
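    '''Bring up the nbdkit instances, NBD devices and bcache devices, then
    create or import the zpool on top of bcache and mount the dataset.'''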
    tempdir = exit_stack.enter_context(tempfile.TemporaryDirectory())

    (nbdkits, sockets) = start_nbdkits(exit_stack, tempdir, variant=options.backend)
    nbd_devs = connect_nbds(exit_stack, sockets, bcache_register=not options.create)
    cache_devs = connect_caches(exit_stack, nbd_devs, bcache_register=not options.create)

    bcache_devs = {}
    for kind in nbd_devs.keys():
        nbd_dev = nbd_devs[kind]
        cache_dev = cache_devs[kind]

        if options.create:
            # Maximum bcache blocksize is page size (4kb), data offset has to be
            # at least 16 sectors.
            data_offset = max(16, 2*(
                S3_BLOCKSIZE_SMALL_KB if kind == 'sb' else S3_BLOCKSIZE_LARGE_KB))
            run(['make-bcache', '--data-offset', str(data_offset),
                 '--block', '4k', '--cache_replacement_policy=lru',
                 '--cache', '/dev/'+cache_dev, '--bdev', '/dev/'+nbd_dev])

        wait_for(os.path.exists, f'/sys/block/{nbd_dev}/bcache/dev', timeout=30,
                 desc=f'bcache registration of {nbd_dev}')

        bd = os.path.basename(os.readlink(f'/sys/block/{nbd_dev}/bcache/dev'))
        bcache_devs[kind] = bd
        log.debug('Found bcache device %s', bd)

        # Always use cache
        write_file(f'/sys/block/{nbd_dev}/bcache/sequential_cutoff', '0')
        write_file(f'/sys/block/{nbd_dev}/bcache/cache/congested_read_threshold_us', '0')
        write_file(f'/sys/block/{nbd_dev}/bcache/cache/congested_write_threshold_us', '0')

        # Minimize amount of dirty data
        write_file(f'/sys/block/{nbd_dev}/bcache/cache_mode', 'writeback')
        write_file(f'/sys/block/{nbd_dev}/bcache/writeback_delay', '0')
        write_file(f'/sys/block/{nbd_dev}/bcache/writeback_percent', '0')

        # Without this, some ZFS operations force data to go to the backing device
        path = '/sys/module/zfs/parameters/zfs_nocacheflush'
        exit_stack.callback(write_file, path, read_value(path))
        write_file(path, '1')
        path = '/sys/module/zfs/parameters/zil_nocacheflush'
        exit_stack.callback(write_file, path, read_value(path))
        write_file(path, '1')


    if options.create:
        create_zpool(exit_stack, bcache_devs)
    else:
        run(['zpool', 'import', '-R', ZPOOL_DIR,
             '-d', '/dev/' + bcache_devs['lb'],
             '-d', '/dev/' + bcache_devs['sb'], ZPOOL_NAME ])
        exit_stack.callback(zpool_export, exit_stack)

        run(['zfs', 'load-key', ZPOOL_NAME])
        run(['zfs', 'mount', f'{ZPOOL_NAME}/{DSET_NAME}'])

    exit_stack.callback(wait_for_cache, nbd_devs)


def mount(exit_stack, options):
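    '''Bring up the nbdkit instances and NBD devices, then create or import
    the zpool directly on top of them (no bcache) and mount the dataset.'''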
    tempdir = exit_stack.enter_context(tempfile.TemporaryDirectory())

    (nbdkits, sockets) = start_nbdkits(exit_stack, tempdir, variant=options.backend)
    nbd_devs = connect_nbds(exit_stack, sockets, bcache_register=False)

    if options.create:
        create_zpool(exit_stack, nbd_devs)
    else:
        run(['zpool', 'import', '-R', ZPOOL_DIR,
             '-d', '/dev/' + nbd_devs['lb'],
             '-d', '/dev/' + nbd_devs['sb'], ZPOOL_NAME ])
        exit_stack.callback(zpool_export, exit_stack)

        run(['zfs', 'load-key', ZPOOL_NAME])
        run(['zfs', 'mount', f'{ZPOOL_NAME}/{DSET_NAME}'])


# Just a parking place for now
def run_backup(mountpoint, options):
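    '''Rsync the source directories into the dataset, create a timestamped
    snapshot, and expire old snapshots.'''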
    log.info("Running backup...")

    rsync('/usr/local/', f'{mountpoint}/usr_local/', options)
    rsync('/etc/', f'{mountpoint}/etc/', options)
    rsync('/home/', f'{mountpoint}/home/', options,
          exclude = [
              'Cache/', # e.g. .config/Code/Cache
              '/*/.mozilla/',  # covered by Firefox sync
              '__pycache__/',
              '/*/.local/share/Trash/',
          ])

    # Make snapshot
    strdate = datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
    snapshot_name = f'{ZPOOL_NAME}/{DSET_NAME}@{strdate}'
    log.info('Creating snapshot %s...', snapshot_name)
    run(['zfs', 'snapshot', snapshot_name])
    run(['expire_zfs_snapshots.py', f'{ZPOOL_NAME}/{DSET_NAME}' ]
        + [ str(x) for x in KEEP_BACKUPS],
        cwd=mountpoint)


if __name__ == '__main__':
    main()

