#!/usr/bin/env python3

# Andrew B. Collier <andrew.b.collier@gmail.com>

import datetime
import itertools
import os
import struct
import argparse
import numpy as np
from osgeo import gdal
from osgeo import osr
from PIL import Image

# TODO:
#
# - Fix transformation for GeoTIFF files (which will give correct latitude and longitude).

# Cache of compiled ``struct.Struct`` objects, keyed by format string, so each
# format is only parsed once no matter how many values are read with it.
structs = {}


def unpack(stream, format):
    """Read one value described by *format* from *stream*.

    Args:
        stream: Binary file-like object providing ``read``.
        format: A ``struct`` format string (``str`` or ``bytes``). Only the
            first unpacked item is returned, so it should describe one value.

    Returns:
        The first item produced by unpacking ``struct.calcsize(format)`` bytes.

    Raises:
        EOFError: If the stream ends before enough bytes could be read.
    """
    try:
        s = structs[format]
    except KeyError:
        s = structs[format] = struct.Struct(format)
    raw = stream.read(s.size)
    # An explicit check (rather than ``assert``) still runs under ``python -O``
    # and also catches partial reads, not just completely empty ones.
    if len(raw) != s.size:
        raise EOFError(
            "unexpected eof: wanted {} bytes, got {}".format(s.size, len(raw))
        )
    return s.unpack(raw)[0]


def grouper(iterable, n, fillvalue=None):
    """Lazily split *iterable* into tuples of exactly *n* items.

    The final tuple is padded with *fillvalue* when the input runs out early,
    e.g. ``grouper('ABCDEFG', 3, 'x')`` yields ``ABC``, ``DEF``, ``Gxx``.
    """
    shared = iter(iterable)
    # Handing the *same* iterator to zip_longest n times makes every output
    # tuple consume n consecutive items from the input.
    return itertools.zip_longest(*itertools.repeat(shared, n), fillvalue=fillvalue)


def integer(f):
    """Read one little-endian unsigned 32-bit integer from *f*."""
    value = unpack(f, b"<I")
    return value


def float64(f):
    """Read one little-endian IEEE 754 double (8 bytes) from *f*.

    The byte order is stated explicitly, like the other readers in this file,
    so that the result does not depend on the host machine's endianness.
    """
    return unpack(f, b"<d")


def byte(f):
    """Read a single unsigned byte (0-255) from *f*."""
    value = unpack(f, b"B")
    return value


def short(f):
    """Read one little-endian unsigned 16-bit integer from *f*."""
    value = unpack(f, b"<H")
    return value


def string(f):
    """Read a NUL-terminated string starting at the current position of *f*.

    Bytes are consumed one at a time up to and including the terminating NUL
    (or up to end of file). The result is decoded with the default codec; if
    decoding fails the raw bytes are returned instead.
    """
    collected = []
    while True:
        ch = f.read(1)
        # Stop at end of file or at the terminator; the terminator itself is
        # consumed but not kept.
        if ch == b"" or ch == b"\0":
            break
        collected.append(ch)
    raw = b"".join(collected)

    try:
        return raw.decode()
    except (UnicodeDecodeError, AttributeError):
        return raw


def pointer(f, value):
    """Follow a 32-bit file offset and read a value from where it points.

    An offset is read from the current position, ``value(f)`` is called with
    the stream positioned at that offset, and the stream is then put back just
    after the offset field - even if ``value`` raises.
    """
    target = integer(f)
    resume_at = f.tell()
    f.seek(target)
    try:
        result = value(f)
    finally:
        # Always restore the position so the caller can carry on reading the
        # structure that contained the pointer.
        f.seek(resume_at)
    return result


def date_time(f):
    """Read a 32-bit timestamp (seconds since the epoch) from *f*.

    Returns a naive ``datetime`` expressed in the local time zone, which is
    what ``datetime.fromtimestamp`` produces when no ``tz`` is given.
    """
    seconds = integer(f)
    return datetime.datetime.fromtimestamp(seconds)


def array(f, size, value):
    """Read *size* consecutive items from *f* using the reader *value*.

    *size* is either a number or a callable; a callable is invoked with *f*
    first (so the count itself can be read from the stream). Returns a list
    of the values produced by calling ``value(f)`` that many times.
    """
    count = size(f) if callable(size) else size
    items = []
    for _ in range(count):
        items.append(value(f))
    return items


def datum_shift(f):
    """Read a datum shift record: two doubles, north first and then east."""
    north = float64(f)
    east = float64(f)
    return {"north": north, "east": east}


def licence_information(f):
    """Read the licence information structure at the current position of *f*.

    Returns a dict with the keys ``identifier``, ``license_description`` and
    ``serial_number`` (a list of 32 byte values). Two runs of bytes (8 and 84
    long) are skipped without being interpreted.
    """
    identifier = integer(f)
    f.seek(8, 1)  # skipped, not interpreted here
    description = pointer(f, string)
    serial_number = array(f, 32, byte)
    f.seek(84, 1)  # skipped, not interpreted here
    return {
        "identifier": identifier,
        "license_description": description,
        "serial_number": serial_number,
    }


def digital_map_shop(f):
    """Read the digital map shop structure: an integer ``size`` and a ``url``.

    The URL is stored as a pointer to a NUL-terminated string.
    """
    size = integer(f)
    url = pointer(f, string)
    return {"size": size, "url": url}


def extended_data_structure(f):
    """Read the extended metadata structure at the current position of *f*.

    Every field is stored as a pointer to its actual value. Eight bytes
    between ``disk_name`` and ``license_information`` are skipped without
    being interpreted.
    """
    record = {}
    record["map_type"] = pointer(f, string)
    record["datum_shift"] = pointer(f, datum_shift)
    record["disk_name"] = pointer(f, string)
    f.seek(8, 1)  # skipped, not interpreted here
    record["license_information"] = pointer(f, licence_information)
    record["associated_data"] = pointer(f, string)
    record["digital_map_shop"] = pointer(f, digital_map_shop)
    return record


def map_outline_point(f):
    """Read one map outline point: two doubles, latitude then longitude."""
    lat = float64(f)
    lon = float64(f)
    return {"lat": lat, "lon": lon}


def meta_data(f):
    """Read the metadata header at the current position of *f*.

    Returns a dict holding the header fields in file order, followed by
    ``extended_data``, ``map_outline_points`` (the number of outline points)
    and ``map_outline`` (a list with that many points).
    """

    def text(stream):
        # Text fields are stored as a pointer to a NUL-terminated string.
        return pointer(stream, string)

    # (key, reader) pairs in exactly the order the fields appear in the file.
    header_layout = (
        ("magic", integer),
        ("version", integer),
        ("width", integer),
        ("height", integer),
        ("long_title", text),
        ("name", text),
        ("identifier", text),
        ("edition", text),
        ("revision", text),
        ("keywords", text),
        ("copyright", text),
        ("scale", text),
        ("datum", text),
        ("depths", text),
        ("heights", text),
        ("projection", text),
        ("bit_field", integer),
        ("original_file_name", text),
        ("original_file_size", integer),
        ("original_file_creation_time", date_time),
    )
    data = {}
    for key, read_field in header_layout:
        data[key] = read_field(f)

    f.seek(4, 1)  # skipped, not interpreted here
    data["extended_data"] = pointer(f, extended_data_structure)

    outline_count = integer(f)
    data["map_outline_points"] = outline_count

    def outline(stream):
        # The point count is stored in the header, ahead of the pointer to
        # the points themselves.
        return array(stream, outline_count, map_outline_point)

    data["map_outline"] = pointer(f, outline)

    return data


def deinterlace(y):
    """Return *y* with its 6-bit binary representation reversed.

    For example ``0b110100`` becomes ``0b001011``.
    """
    bits = format(y, "06b")
    return int(bits[::-1], 2)


# Edge length, in pixels, of the square tiles the image is decoded in.
TILE_SIZE = 64


def draw_image(f, data, drawing, x, y):
    """Decode the tile at the current position of *f* and paint it.

    Args:
        f: Binary stream positioned at the first byte of the tile data.
        data: Unused.
        drawing: Pixel-access object indexed as ``drawing[column, row]``.
        x: Column of the tile's left edge in *drawing*.
        y: Row of the tile's top edge in *drawing*.

    Decoding is best effort: any error is reported on stdout together with a
    traceback, and the tile is left partially drawn.
    """
    start = f.tell()
    marker = byte(f)
    try:
        # The first byte selects the encoding of the tile.
        if marker == 0 or marker == 0xFF:
            decoder = decode_huffman
        elif marker > 127:
            decoder = decode_packed
        else:
            decoder = decode_rle
        pixels = decoder(f)

        columns = range(x, x + TILE_SIZE)
        for stored_row in range(TILE_SIZE):
            # Rows are stored interlaced; map the stored order to the real row.
            target_row = y + deinterlace(stored_row)
            for column in columns:
                drawing[column, target_row] = next(pixels)
    except Exception as e:
        print("error drawing tile", start, ":", e)
        import traceback

        traceback.print_exc()


def decode_rle(f):
    """Yield palette indices of a run-length encoded tile, without end.

    The stream is expected to sit just after the tile's first byte, which is
    re-read here as the sub-palette length. Each following byte holds a
    sub-palette index in its low bits and a repeat count in the remaining
    high bits.
    """
    f.seek(-1, 1)
    palette_size = byte(f)
    palette = array(f, palette_size, byte)
    assert palette_size
    # Number of bits needed to address every sub-palette entry.
    index_bits = (palette_size - 1).bit_length()
    index_mask = (1 << index_bits) - 1
    while True:
        packed = byte(f)
        colour = palette[packed & index_mask]
        yield from itertools.repeat(colour, packed >> index_bits)


def iter_bits(f):
    """Yield the bits of *f* one at a time, least significant bit first.

    Bytes are read on demand; the generator never finishes by itself.
    """
    while True:
        value = byte(f)
        for position in range(8):
            yield (value >> position) & 1


def decode_code_book(f):
    """Read a Huffman code book from *f* and return it as a flat list.

    Entries below 128 are colours (leaves); everything else is a branch. A
    branch equal to 128 takes three slots: the marker, a jump distance
    computed as ``65539 - short(f)``, and a ``None`` placeholder. Reading
    stops as soon as there are more colours than branches.
    """
    entries = []
    leaf_count = 0
    branch_count = 0
    while leaf_count <= branch_count:
        code = byte(f)
        entries.append(code)
        if code < 128:
            leaf_count += 1
            continue
        if code == 128:
            # Far branch: the jump distance follows as a 16-bit value.
            entries.extend((65539 - short(f), None))
        branch_count += 1
    return entries


def decode_huffman(f):
    """Yield palette indices of a Huffman encoded tile, without end.

    The code book is read first. If it has a single entry, that colour is
    repeated forever and no bits are read. Otherwise each pixel is found by
    walking the code book from index 0, one input bit per branch, until a
    colour (a value below 128) is reached.
    """
    code_book = decode_code_book(f)
    if len(code_book) == 1:
        yield from itertools.repeat(code_book[0])
    bits = iter_bits(f)
    while True:
        position = 0
        code = code_book[position]
        while code >= 128:
            far = code == 128
            if next(bits):
                # Bit set: jump by the stored distance (far branch) or by
                # the distance encoded in the branch value itself.
                position += code_book[position + 1] if far else 257 - code
            else:
                # Bit clear: step past this branch (a far branch is 3 slots).
                position += 3 if far else 1
            code = code_book[position]
        yield code


def decode_packed(f):
    """Yield palette indices of a pixel-packed tile, without end.

    The stream is expected to sit just after the tile's first byte, which is
    re-read here. For this encoding that byte is greater than 127 and stores
    the sub-palette size as ``256 - size`` (per the QCT format specification),
    so the size is recovered by subtracting it from 256. Pixels are then
    packed, lowest bits first, into little-endian 32-bit words; bits left
    over at the top of a word are ignored.

    Raises:
        ValueError: If the sub-palette has fewer than two colours, which
            leaves no bits per pixel to unpack.
    """
    f.seek(-1, 1)
    sub_palette_len = 256 - byte(f)
    sub_palette = array(f, sub_palette_len, byte)
    # Number of bits needed to address every sub-palette entry.
    next_shift = (sub_palette_len - 1).bit_length()
    if not next_shift:
        raise ValueError("packed tile needs at least two sub-palette colours")
    sub_palette_mask = (1 << next_shift) - 1
    total_per_int = 32 // next_shift
    while True:
        i = integer(f)
        for _ in range(total_per_int):
            yield sub_palette[i & sub_palette_mask]
            i >>= next_shift


def qct(f, basename, scale=None):
    """Decode a QCT stream and write ``<basename>.png`` and ``<basename>.tif``.

    Args:
        f: Binary, seekable stream positioned at the start of the QCT data.
        basename: Output path without extension.
        scale: Optional resize factor applied to the decoded image. ``None``,
            ``0`` and ``1`` all leave the image at its native size.

    Raises:
        RuntimeError: If GDAL cannot create the GeoTIFF file.

    The metadata and the geographical referencing coefficients are printed to
    stdout. The geotransform written to the GeoTIFF is a hard-coded
    placeholder (see the TODO at the top of the file), so the georeferencing
    of the output is not derived from the input file.
    """
    meta = meta_data(f)
    geo_refs = array(f, 40, float64)
    # 1024 palette bytes are read, but only the first 512 (128 four-byte
    # entries stored as blue, green, red, unused) are turned into colours.
    palette = [(r, g, b) for b, g, r, _ in grouper(array(f, 1024, byte)[:512], 4)]
    # This table is not used, but it must be read so that the stream ends up
    # positioned at the tile index that follows it.
    array(f, 16384, byte)

    print(meta)

    # Geographical Referencing Coefficients
    #
    geo_refs = np.asarray(geo_refs).reshape((10, 4), order="F")
    #
    # Note: These must be corrected with the datum shift from the extended metadata (if it's non-zero).

    print(geo_refs)

    # Image dimensions in tiles.
    tiles_wide = meta["width"]
    tiles_high = meta["height"]

    image_index = array(f, tiles_wide * tiles_high, integer)

    # Start from palette index 0 rather than an uninitialised buffer, so that
    # missing or undecodable tiles give deterministic output.
    image = Image.new("P", (tiles_wide * TILE_SIZE, tiles_high * TILE_SIZE), 0)
    image.putpalette(list(itertools.chain.from_iterable(palette)))
    drawing = image.load()
    for tile_number, offset in enumerate(image_index):
        # A zero offset means there is no data for this tile.
        if offset:
            tile_y, tile_x = divmod(tile_number, tiles_wide)
            f.seek(offset)
            draw_image(f, 0, drawing, tile_x * TILE_SIZE, tile_y * TILE_SIZE)

    # Convert to RGB *before* scaling: Pillow always uses nearest-neighbour
    # resampling for palette ("P") images, so the requested bicubic filter
    # only takes effect on an RGB image.
    image = image.convert("RGB")

    # Scale size of image.
    if scale and scale != 1:
        image = image.resize(
            (
                round(tiles_wide * TILE_SIZE * scale),
                round(tiles_high * TILE_SIZE * scale),
            ),
            Image.BICUBIC,
        )

    # Get actual image dimensions (Pillow reports width first).
    px_width, px_height = image.size

    # PNG
    #
    image.save(basename + ".png", optimize=True)

    # Convert image to an array of shape (rows, columns, channels).
    pixels = np.array(image)

    # GeoTIFF
    #
    tif = gdal.GetDriverByName("GTiff").Create(
        basename + ".tif",
        px_width,
        px_height,
        bands=3,
        eType=gdal.GDT_Byte,
        # Use deflate compression with horizontal differencing predictor.
        options=["COMPRESS=DEFLATE", "PREDICTOR=2"],
    )
    if tif is None:
        raise RuntimeError("could not create " + basename + ".tif")

    # This doesn't seem to give precisely the correct results.
    #
    # Complications:
    #
    # - sometimes the rectangle can be slightly rotated
    # - sometimes the corners of the map are not actually filled in, which means that the map
    #   outline doesn't always have the precise coordinates of the corners.
    #
    # Use these command line utilities to see the corner coordinates in the resulting GeoTIFF:
    #
    # - listgeo and
    # - gdalinfo.
    #
    # Look at the sample SU99.tif file for an example in this projection.
    #
    transform = (
        # X coordinate of top/left pixel (placeholder; should come from xmin).
        428000,
        # Pixel width (placeholder; should be (xmax - xmin) / width).
        2.5,
        # Rotation.
        0,
        # Y coordinate of top/left pixel (placeholder; should come from ymax).
        180000,
        # Rotation.
        0,
        # Pixel height (placeholder; should be (ymin - ymax) / height).
        -2.5,
    )
    print(transform)
    tif.SetGeoTransform(transform)

    # Set projection:
    #
    # 3857  — WGS84 Pseudo-Mercator (Spherical Mercator, Google Maps, OpenStreetMap, Bing, ArcGIS, ESRI)
    # 4277  — OSGB 1936 datum used by UK Ordnance Survey maps
    # 27700 — OSGB 1936 British National Grid
    #
    srs = osr.SpatialReference()
    srs.ImportFromEPSG(27700)
    tif.SetProjection(srs.ExportToWkt())
    # Write colour channels (GDAL bands are numbered from 1).
    for channel in range(3):
        tif.GetRasterBand(channel + 1).WriteArray(pixels[:, :, channel])
    # Flush to disk & close.
    tif.FlushCache()
    tif = None


if __name__ == "__main__":
    # Command-line entry point: convert every QCT file named on the command
    # line, optionally scaling the output images.
    cli = argparse.ArgumentParser(
        description="Convert QCT files.",
        usage="%(prog)s [OPTION] [FILE]",
    )
    cli.add_argument("-s", "--scale", type=float)
    cli.add_argument("files", nargs="*")

    options = cli.parse_args()

    for source_path in options.files:
        # Outputs are named after the input file, without its directory or
        # extension, so they are written to the current working directory.
        stem, _extension = os.path.splitext(os.path.basename(source_path))

        with open(source_path, "rb") as handle:
            qct(handle, stem, options.scale)
