#!/usr/bin/env python3
'''
# Copyright (C) 2020, Elphel.inc.
# Usage: sudo write_bootable_mmc.py [/dev/sdX [image.img]]  (run with --help for details)
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

@author:     Oleg K Dzhimiev, Konstantyn Chebanov
@copyright:  2016 Elphel, Inc.
@license:    GPLv3.0+
@contact:    oleg@elphel.com
@deffield    updated: unknown

'''

# Module metadata; informational only (nothing in this script reads these names).
__author__ = "Elphel"
__copyright__ = "Copyright 2020, Elphel, Inc."
__license__ = "GPL"
__version__ = "4.0"
__maintainer__ = "Oleg K Dzhimiev"
__email__ = "oleg@elphel.com"
__status__ = "Development"

import json
import os
import shlex
import shutil
import stat
import subprocess
import sys
import tempfile
import time

# Size window (bytes) for auto-detected target devices. Nominal 8 GB and 16 GB
# cards report roughly 7.45 GiB and 14.9 GiB, so 7..17 GiB brackets both.
MIN_DEVICE_BYTES = 7 << 30
MAX_DEVICE_BYTES = 17 << 30

# functions
# useful link 1: http://superuser.com/questions/868117/layouting-a-disk-image-and-copying-files-into-it
# useful link 2: poky/scripts/contrib/mkefidisk.sh
# (?) useful link 3: http://unix.stackexchange.com/questions/53890/partitioning-disk-image-file

def shout(cmd, check=False):
    """Run *cmd* through the shell, letting its output go straight to the console.

    Args:
        cmd: shell command line. It is interpreted by /bin/sh, so callers must
            quote any untrusted parts themselves (e.g. with shlex.quote).
        check: when True, exit the script with status 1 if the command fails.

    Returns:
        The command's exit status (0 on success).

    A non-zero status is reported instead of being silently ignored, so a
    failed dd/parted/mkfs step is visible in the log.
    """
    rc = subprocess.call(cmd, shell=True)
    if rc != 0:
        print(f"WARNING: command failed with exit status {rc}: {cmd}")
        if check:
            sys.exit(1)
    return rc

def timestamp():
    """Format the current local time as 'YYYY-MM-DD HH:MM:SS'."""
    now = time.localtime()
    return time.strftime("%Y-%m-%d", now) + " " + time.strftime("%H:%M:%S", now)

def begin_phase(name):
    """Announce phase *name* with a timestamp; return a monotonic start mark for end_phase()."""
    print("[{}] {}".format(timestamp(), name))
    return time.monotonic()

def end_phase(name, start_t):
    """Report how long phase *name* took since *start_t* (the mark from begin_phase())."""
    seconds = time.monotonic() - start_t
    print("[{}] {} done in {:.1f}s".format(timestamp(), name, seconds))

def print_help():
    """Print usage, prerequisites and examples to stdout.

    The prerequisite list names every external tool the write procedure
    invokes (rsync only in image mode), not just the three that used to be listed.
    """
    print("\nDescription:\n")
    print("  * Required programs: lsblk, parted, kpartx, dd, mkfs.vfat, mkfs.ext4, tar")
    print("    (plus rsync when writing from an *.img file)")
    print("  * Run under superuser. Make sure the correct device is provided.")
    print("  * Erases partition table on the provided device")
    print("  * If given someimage.img file - burns the sd card from it")
    print("  * If not - uses the files from the predefined list")
    print("  * Creates FAT32 partition labeled 'BOOT' and copies files required for boot")
    print("  * Creates EXT4 partition labeled 'root' and extracts rootfs.tar.gz")
    print("\nExamples:\n")
    print("  * Use files (names are hardcoded) from the current dir ('build/tmp/deploy/images/elphel393/mmc/'):")
    print("      ~$ sudo write_bootable_mmc.py /dev/sdz")
    print("  * Use someimage.img file:")
    print("      ~$ sudo write_bootable_mmc.py /dev/sdz someimage.img")
    print("  * Auto-detect likely removable 8/16GB devices and pick interactively:")
    print("      ~$ sudo write_bootable_mmc.py")
    print("  * To write *.iso use a standard OS tool that burns bootable USB drives")
    print("")

def _promote_mountpoint(devices):
    """Rewrite legacy single-valued ``mountpoint`` keys into ``mountpoints`` lists, recursively."""
    for dev in devices or []:
        dev["mountpoints"] = [dev.pop("mountpoint", None)]
        _promote_mountpoint(dev.get("children"))


def list_block_devices():
    """Return the top-level block devices reported by ``lsblk -b -J``.

    Each entry is the dict lsblk emits for the requested columns (lower-case
    keys such as ``name``, ``type``, ``size``, ``rm``, ``mountpoints``) with
    nested ``children``.

    The MOUNTPOINTS column only exists in newer util-linux releases. If lsblk
    rejects the first query, it is retried with the older MOUNTPOINT column
    and without PATH. Each legacy ``mountpoint`` value is then wrapped in a
    one-element ``mountpoints`` list, so callers see the same shape either
    way. A missing ``path`` is already handled downstream by falling back to
    /dev/<name>.

    Raises:
        subprocess.CalledProcessError: if the fallback query fails as well.
    """
    try:
        out = subprocess.check_output(
            ["lsblk", "-b", "-J", "-o", "NAME,PATH,TYPE,SIZE,RM,TRAN,MODEL,SERIAL,MOUNTPOINTS"],
            text=True,
            stderr=subprocess.DEVNULL,  # an "unknown column" error is expected on old lsblk
        )
        legacy = False
    except subprocess.CalledProcessError:
        out = subprocess.check_output(
            ["lsblk", "-b", "-J", "-o", "NAME,TYPE,SIZE,RM,TRAN,MODEL,SERIAL,MOUNTPOINT"],
            text=True,
        )
        legacy = True
    devices = json.loads(out).get("blockdevices", [])
    if legacy:
        _promote_mountpoint(devices)
    return devices

def has_mounted_children(dev):
    """Tell whether any descendant of lsblk node *dev* has a mountpoint.

    The node's own ``mountpoints`` are not consulted, only its children and
    their subtrees. A ``mountpoints`` value counts as mounted when it is a
    list containing a truthy entry, or any other truthy value.
    """
    def _node_mounted(node):
        points = node.get("mountpoints")
        return any(points) if isinstance(points, list) else bool(points)

    kids = dev.get("children", []) or []
    return any(_node_mounted(kid) or has_mounted_children(kid) for kid in kids)

def as_disk_entry(dev):
    """Flatten one lsblk ``disk`` node into the dict used by the selection code.

    Missing or null lsblk fields are normalized:
    - ``path`` falls back to ``/dev/<name>`` (or "" without a name);
    - ``size`` and ``rm`` become ints defaulting to 0;
    - ``model``, ``serial`` and ``tran`` become stripped strings, and ``tran``
      is also lower-cased.

    ``mounted`` is True when the disk node itself or any node below it has a
    mountpoint. Checking the disk's own mountpoints matters for media
    formatted without a partition table (filesystem directly on the whole
    device). A partition-only check would report such a disk as unmounted and
    offer it as an erase candidate.
    """
    path = dev.get("path")
    if not path:
        name = dev.get("name", "")
        path = f"/dev/{name}" if name else ""

    own_mounts = dev.get("mountpoints")
    if isinstance(own_mounts, list):
        self_mounted = any(own_mounts)
    else:
        self_mounted = bool(own_mounts)

    return {
        "path": path,
        "size": int(dev.get("size") or 0),
        "model": (dev.get("model") or "").strip(),
        "serial": (dev.get("serial") or "").strip(),
        "tran": (dev.get("tran") or "").strip().lower(),
        "rm": int(dev.get("rm") or 0),
        "mounted": self_mounted or has_mounted_children(dev),
    }

def get_disks():
    """Collect every lsblk node of type "disk" as a normalized entry (see as_disk_entry)."""
    return [
        as_disk_entry(node)
        for node in list_block_devices()
        if node.get("type") == "disk"
    ]

def format_gib(size_bytes):
    """Render a byte count as binary gigabytes with one decimal, e.g. "14.9 GiB"."""
    gib = size_bytes / float(1 << 30)
    return "{:.1f} GiB".format(gib)

def choose_device_interactive(disks):
    """Let the user pick the target among likely removable 8/16GB disks.

    A disk from *disks* (entries as produced by get_disks()) is offered only
    when all of the following hold:
    - its size lies within [MIN_DEVICE_BYTES, MAX_DEVICE_BYTES];
    - it is not reported as mounted;
    - it is flagged removable (rm == 1) or attached over usb/mmc/sdio.

    Returns:
        The ``path`` of the chosen disk.

    Exits with status 1 in any of these cases:
    - no candidate exists (after listing all discovered disks);
    - the user enters q/quit/exit;
    - stdin reaches end-of-file.
    """
    candidates = [
        d for d in disks
        if MIN_DEVICE_BYTES <= d["size"] <= MAX_DEVICE_BYTES
        and not d["mounted"]
        and (d["rm"] == 1 or d["tran"] in ("usb", "mmc", "sdio"))
    ]

    if not candidates:
        print("ERROR: no safe removable 8/16GB candidate device found.")
        print("Disks discovered:")
        for d in disks:
            mounted = "mounted" if d["mounted"] else "not-mounted"
            print(f"  {d['path']:<14} {format_gib(d['size']):>8} rm={d['rm']} tran={d['tran'] or '-'} {mounted}")
        print("Provide target device explicitly, for example:")
        print(f"  sudo {os.path.basename(sys.argv[0])} /dev/sdX")
        sys.exit(1)

    print("Select target block device:")
    for i, d in enumerate(candidates, start=1):
        label = " ".join(x for x in (d["model"], d["serial"]) if x).strip()
        print(
            f"  [{i}] {d['path']:<14} {format_gib(d['size']):>8} "
            f"rm={d['rm']} tran={d['tran'] or '-'} {label}"
        )

    prompt = f"Enter number [1-{len(candidates)}] or 'q' to quit: "
    while True:
        try:
            val = input(prompt).strip().lower()
        except EOFError:
            # stdin closed (e.g. Ctrl-D): behave like 'q' instead of crashing
            # with a traceback.
            print()
            sys.exit(1)
        if val in ("q", "quit", "exit"):
            sys.exit(1)
        if val.isdigit():
            idx = int(val)
            if 1 <= idx <= len(candidates):
                return candidates[idx - 1]["path"]
        print("Invalid selection.")

def get_disk_meta(disks, device):
    """Return the entry of *disks* describing *device*, or None if none matches.

    An entry matches when its ``path`` equals *device* literally, or when both
    resolve to the same file via os.path.realpath(). The latter lets a symlink
    such as /dev/disk/by-id/... find the entry lsblk reported under its
    canonical /dev name. The first matching entry wins.
    """
    wanted = os.path.realpath(device)
    for entry in disks:
        path = entry["path"]
        # Skip empty paths: realpath("") would resolve to the working directory.
        if path == device or (path and os.path.realpath(path) == wanted):
            return entry
    return None

def ensure_running_as_root():
    """Exit with status 1 and a ready-to-paste sudo hint unless the effective UID is 0."""
    if os.geteuid() == 0:
        return
    quoted_argv = " ".join(map(shlex.quote, sys.argv))
    print("ERROR: this program must be launched with sudo.")
    print("Try:\n  sudo " + quoted_argv)
    sys.exit(1)

def ensure_block_device(device):
    """Exit with status 1 unless *device* names an existing block special file.

    Symlinks are followed (os.stat), so /dev/disk/by-id/... paths are accepted
    when they point at a block device.
    """
    try:
        st_mode = os.stat(device).st_mode
    except (OSError, ValueError):
        print(f"No such device: {device}")
        sys.exit(1)
    if stat.S_ISBLK(st_mode):
        return
    print(f"Not a block device: {device}")
    sys.exit(1)

def confirm_erase(device, meta):
    """Show what is about to be erased and require the user to type YES.

    Args:
        device: path of the target block device.
        meta: disk entry (as produced by get_disks()) for *device*, or None
            when no lsblk whole-disk entry matched it.

    Warnings are printed in three cases:
    - the size is outside the expected 8/16GB window;
    - the disk is reported as mounted;
    - no whole-disk entry matched (e.g. a partition path was given).

    Exits the process with status 1 unless the answer is exactly "YES"
    (end-of-file on stdin counts as a refusal).
    """
    if meta:
        size = format_gib(meta["size"])
        tran = meta["tran"] or "-"
        model = " ".join(x for x in (meta["model"], meta["serial"]) if x).strip()
        print(f"Selected: {device} ({size}, rm={meta['rm']}, tran={tran}) {model}".rstrip())
        if meta["size"] < MIN_DEVICE_BYTES or meta["size"] > MAX_DEVICE_BYTES:
            print("WARNING: selected device size is outside expected 8/16GB range.")
        if meta.get("mounted"):
            print(f"WARNING: {device} (or one of its partitions) is currently mounted.")
    else:
        print(f"WARNING: no lsblk whole-disk entry matches {device}; make sure it is not a partition.")
    print(f"WARNING: this will erase all data on {device}.")
    try:
        answer = input("Type YES to continue: ").strip()
    except EOFError:
        print()
        answer = ""
    if answer != "YES":
        print("Aborted.")
        sys.exit(1)

def check_program_installed(program):
    """Return True if *program* is found on PATH; otherwise print a notice and return False.

    Uses shutil.which rather than spawning the external ``which`` binary.
    That binary is not guaranteed to exist on minimal systems, and its absence
    would raise FileNotFoundError instead of reporting the missing program.
    """
    if shutil.which(program) is not None:
        return True
    print(f"Program missing: {program}")
    return False

# Check required programs: the external tools the write procedure invokes.
# Checking them up front matters because a missing dd/mkfs would otherwise
# only surface after the target device had already been wiped.
required_programs = (
    "lsblk",
    "parted",
    "kpartx",
    "dd",
    "mkfs.vfat",
    "mkfs.ext4",
    "tar",
)

something_is_missing = False
for program in required_programs:
    # Keep looping (no early exit) so every missing tool is reported at once.
    if not check_program_installed(program):
        something_is_missing = True

if something_is_missing:
    print("\nPlease install missing programs and run again.")
    sys.exit(1)

# Parse command line arguments: any -h/--help prints usage and exits with 0;
# everything else is positional: DEVICE [IMAGE_FILE].
if any(a in ("-h", "--help") for a in sys.argv[1:]):
    print_help()
    sys.exit(0)
args = list(sys.argv[1:])

if len(args) > 2:
    print("ERROR: wrong number of arguments.")
    print_help()
    sys.exit(1)

ensure_running_as_root()

# Pick the target: explicit first argument, or interactive choice among
# likely removable disks (only possible with a terminal on stdin).
disks = get_disks()
if len(args) > 0:
    DEVICE = args[0]
else:
    if not sys.stdin.isatty():
        print("ERROR: no target device provided and no interactive terminal to choose one.")
        print(f"Use: sudo {os.path.basename(sys.argv[0])} /dev/sdX [image.img]")
        sys.exit(1)
    DEVICE = choose_device_interactive(disks)

if len(args) > 1:
    IMAGE_FILE = args[1]
    if not IMAGE_FILE.endswith(".img"):
        print("ERROR: Please provide *.img file or leave argument empty to use certain image files in the current dir")
        sys.exit(1)
else:
    IMAGE_FILE = ""

ensure_block_device(DEVICE)
# Canonicalize symlinks such as /dev/disk/by-id/... to the real node
# (/dev/sdX). Partition names are derived by appending to DEVICE, and lsblk
# metadata is keyed by the canonical path, so both only work on the real name.
DEVICE = os.path.realpath(DEVICE)
confirm_erase(DEVICE, get_disk_meta(disks, DEVICE))

# NOTE(review): presumably refers to the KDE Plasma desktop reacting to the
# disk being repartitioned below — confirm.
print("NOTE: If plasma crashes, do not worry")

# Parameters
# NOTE(review): SDCARD_SIZE is not referenced anywhere else in this script.
SDCARD_SIZE = 4000

# Partition table type passed to `parted mktable`.
PT_TYPE = "msdos"

# Boot partition: label for `mkfs.vfat -n`, fs-type argument for `parted mkpart`.
BOOT_LABEL = "BOOT"
BOOT_FS = "fat32"
# End position of the boot partition for `parted mkpart` (which starts it at 1).
# The root partition starts at BOOT_SIZE + 1; the erase step also uses this
# value as a `dd` seek in 1MB blocks.
BOOT_SIZE = 128
# Files copied from the current directory onto the boot partition when no *.img is given.
BOOT_FILE_LIST = (
    "boot.bin",
    "u-boot-dtb.img",
    "devicetree.dtb",
    "uImage"
)

# Root partition: label for `mkfs.ext4 -L`, fs-type argument for `parted mkpart`.
# ROOT_ARCHIVE is extracted onto it when no *.img is given.
ROOT_LABEL = "root"
ROOT_FS = "ext4"
ROOT_ARCHIVE = "rootfs.tar.gz"

something_is_missing = False

# Check if required files exist (and, in image mode, the tool used to copy them).
if IMAGE_FILE == "":
    print(f"Preparing SD card using files: {BOOT_FILE_LIST + (ROOT_ARCHIVE,)}")
    for f in BOOT_FILE_LIST + (ROOT_ARCHIVE,):
        if not os.path.isfile(f):
            print(f"File {f} is missing")
            something_is_missing = True
else:
    print(f"Preparing SD card from {IMAGE_FILE}")
    if not os.path.isfile(IMAGE_FILE):
        print(f"No such file: {IMAGE_FILE}")
        something_is_missing = True
    # rsync copies the image's partitions; it is only needed in this mode, so
    # it is checked here rather than for every run.
    if not check_program_installed("rsync"):
        something_is_missing = True

if something_is_missing:
    sys.exit(1)

# Main execution
overall_start = time.monotonic()
print(f"[{timestamp()}] Starting SD preparation for {DEVICE}")

# Zero the first 1024 bytes (partition table area), then 1MB at the 1MB and
# BOOT_SIZE MB offsets, near where the partitions created below begin.
t_phase = begin_phase(f"= Erase partition table on {DEVICE} and everything else")
for dd_tail in (
    "bs=512 count=2",
    "bs=1MB count=1 seek=1",
    f"bs=1MB count=1 seek={BOOT_SIZE}",
):
    shout(f"dd if=/dev/zero of={DEVICE} {dd_tail}")
end_phase("Erase", t_phase)

t_phase = begin_phase("= Create partition table")
shout(f"parted -s {DEVICE} mktable {PT_TYPE}")
end_phase("Partition table", t_phase)

# Boot partition from 1 to BOOT_SIZE, root from BOOT_SIZE + 1 to the end,
# followed by parted's optimal-alignment report for both.
t_phase = begin_phase("= Create partitions")
for parted_args in (
    f"mkpart primary {BOOT_FS} 1 {BOOT_SIZE}",
    f"mkpart primary {ROOT_FS} {BOOT_SIZE + 1} 100%",
    "align-check optimal 1",
    "align-check optimal 2",
):
    shout(f"parted -s {DEVICE} {parted_args}")
end_phase("Partition creation", t_phase)

# Wait for the partition device nodes to be created.
# Device names ending in a digit (mmcblk0, nvme0n1, loop0) get a "p" before
# the partition number. The path is canonicalized first so a symlink such as
# /dev/disk/by-id/... yields /dev/sdX1 rather than a non-existent "<link>1".
_part_base = os.path.realpath(DEVICE)
_needs_p = "mmcblk" in _part_base or os.path.basename(_part_base)[-1:].isdigit()
_part_sep = "p" if _needs_p else ""
partition1 = f"{_part_base}{_part_sep}1"
partition2 = f"{_part_base}{_part_sep}2"

# Bounded wait: if the nodes never appear (wrong naming, reader unplugged),
# fail with a message instead of printing dots forever.
PARTITION_WAIT_TIMEOUT = 60  # seconds
t_phase = begin_phase("= Waiting for device nodes...")
deadline = time.monotonic() + PARTITION_WAIT_TIMEOUT
devs_created = False
while not devs_created:
    if os.path.exists(partition1) and os.path.exists(partition2):
        devs_created = True
    elif time.monotonic() > deadline:
        print()
        print(f"ERROR: {partition1} / {partition2} did not appear within {PARTITION_WAIT_TIMEOUT}s")
        sys.exit(1)
    else:
        print(".", end="", flush=True)
        time.sleep(0.5)
print()
end_phase("Device node wait", t_phase)

time.sleep(1)  # short settle delay after the partition nodes show up
t_phase = begin_phase("= Format partitions")

# FAT32 boot partition and ext4 root partition, labelled with BOOT_LABEL / ROOT_LABEL.
for mkfs_cmd in (
    f"mkfs.vfat {partition1} -F 32 -n {BOOT_LABEL}",
    f"mkfs.ext4 {partition2} -F -L {ROOT_LABEL}",
):
    shout(mkfs_cmd)
end_phase("Formatting", t_phase)

# Copy the payload onto the freshly formatted partitions.
# Mount points live in a private directory from tempfile.mkdtemp() instead of
# ./tmp, for two reasons:
# - an unrelated "tmp" in the working directory is never reused;
# - cleanup can use os.rmdir(), which refuses to delete a directory that is
#   still mounted or non-empty. A recursive delete would wipe the card or the
#   image contents if an unmount had failed.
work_dir = tempfile.mkdtemp(prefix="write_bootable_mmc.")
mnt_dst = os.path.join(work_dir, "dst")  # SD card partition
mnt_src = os.path.join(work_dir, "src")  # partition mapped from the source image
active_mounts = []  # mount points currently mounted by this script
image_attached = False
MAPPER_WAIT_TIMEOUT = 60  # seconds


def _mount(dev, mnt):
    """Mount *dev* on *mnt* (creating *mnt*); exit the script if mount fails."""
    os.makedirs(mnt, exist_ok=True)
    if subprocess.call(["mount", dev, mnt]) != 0:
        print(f"ERROR: could not mount {dev} on {mnt}")
        sys.exit(1)
    active_mounts.append(mnt)


def _umount(mnt):
    """Unmount *mnt* if this script mounted it; on failure warn and keep tracking it."""
    if mnt not in active_mounts:
        return
    if subprocess.call(["umount", mnt]) == 0:
        active_mounts.remove(mnt)
    else:
        print(f"WARNING: could not unmount {mnt}")


def _attach_image(image):
    """Map *image*'s partitions with kpartx and return (exit status, /dev/mapper paths).

    kpartx binds the image to whichever loop device is free, which is not
    necessarily loop0. The mapper names are therefore taken from its
    "add map <name> ..." lines rather than hard-coded. ``-s`` asks kpartx to
    wait until the mappings exist.
    """
    subprocess.call(["modprobe", "dm-mod"])
    proc = subprocess.run(
        ["kpartx", "-a", "-s", "-v", image],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
    )
    print(proc.stdout, end="")
    mapper_paths = []
    for line in proc.stdout.splitlines():
        fields = line.split()
        if len(fields) >= 3 and fields[:2] == ["add", "map"]:
            mapper_paths.append(os.path.join("/dev/mapper", fields[2]))
    return proc.returncode, mapper_paths


def _copy_tree(src_dev, dst_dev):
    """Mount both filesystems and rsync the contents of *src_dev* onto *dst_dev*."""
    _mount(dst_dev, mnt_dst)
    _mount(src_dev, mnt_src)
    shout(f"rsync -a {shlex.quote(mnt_src)}/ {shlex.quote(mnt_dst)}")
    _umount(mnt_dst)
    _umount(mnt_src)


try:
    if IMAGE_FILE == "":
        # Copy boot files
        t_phase = begin_phase("= Copying boot files...")
        _mount(partition1, mnt_dst)
        for i in BOOT_FILE_LIST:
            print(f"    {i}")
            shout(f"cp {shlex.quote(i)} {shlex.quote(mnt_dst)}")
        _umount(mnt_dst)
        end_phase("Boot files copy", t_phase)

        # Extract root filesystem
        t_phase = begin_phase("= Extracting root filesystem...")
        _mount(partition2, mnt_dst)
        shout(f"tar -C {shlex.quote(mnt_dst)}/ -xzpf {shlex.quote(ROOT_ARCHIVE)}")
        _umount(mnt_dst)
        end_phase("Rootfs extract", t_phase)
    else:
        # Map the image's partitions
        t_phase = begin_phase("= Copying from image file...")
        image_attached = True  # set first so the cleanup also detaches a partial attach
        kpartx_rc, mapper_parts = _attach_image(IMAGE_FILE)
        end_phase("Image attach", t_phase)
        if kpartx_rc != 0 or len(mapper_parts) < 2:
            print(f"ERROR: kpartx did not map two partitions from {IMAGE_FILE}")
            sys.exit(1)
        src_boot, src_root = mapper_parts[0], mapper_parts[1]

        # Wait for mapper devices (bounded, so a missing node cannot hang forever)
        t_phase = begin_phase("= Waiting for mapper devices...")
        deadline = time.monotonic() + MAPPER_WAIT_TIMEOUT
        while not (os.path.exists(src_boot) and os.path.exists(src_root)):
            if time.monotonic() > deadline:
                print()
                print(f"ERROR: {src_boot} / {src_root} did not appear within {MAPPER_WAIT_TIMEOUT}s")
                sys.exit(1)
            print(".", end="", flush=True)
            time.sleep(0.5)
        print()
        end_phase("Mapper device wait", t_phase)

        t_phase = begin_phase("= Copying boot partition...")
        _copy_tree(src_boot, partition1)
        end_phase("Boot partition copy", t_phase)

        t_phase = begin_phase("= Copying root partition...")
        _copy_tree(src_root, partition2)
        end_phase("Root partition copy", t_phase)
finally:
    # Runs on success and on sys.exit(): release mounts, image mappings and
    # the temporary mount points.
    for mnt in list(reversed(active_mounts)):
        _umount(mnt)
    if image_attached:
        t_phase = begin_phase("= Detaching image mappers...")
        shout(f"kpartx -dv {shlex.quote(IMAGE_FILE)}")
        end_phase("Image detach", t_phase)
    for leftover in (mnt_dst, mnt_src, work_dir):
        try:
            os.rmdir(leftover)
        except FileNotFoundError:
            pass
        except OSError as err:
            print(f"WARNING: could not remove {leftover}: {err}")

shout("sync")  # flush any remaining buffered writes before declaring the card removable

total_elapsed = time.monotonic() - overall_start
print(f"\n[{timestamp()}] Done! SD card is ready for use in {total_elapsed:.1f}s.")
print(f"You can now safely remove {DEVICE}")
