#!/usr/bin/env python3

import math
import time
from bisect import bisect_left
from dataclasses import dataclass
from functools import cmp_to_key, cache, partial
from random import triangular
from typing import List, Union


# Per-element memory footprints (in bytes) used by the builders' memory_used() accounting.
OCTREE_NODE_SIZE = 16 * 8  # 14 numbers of 8 byte each, but accounting for padding in an array
POINT_SIZE = 3 * 8  # three double-precision coordinates
TEMPORARY_COUNT_ELEMENT_SIZE = 8 * 8  # up to seven integers, and one holding the length
INT_MIN = -1000000  # "minus infinity" sentinel returned by xor_msb for equal inputs


def generate_points(root_node_lower_bound, root_node_level, total_count, max_equal_points_count=1):
    """Generate `total_count` points inside the root cell, including duplicates.

    Coordinates are drawn from per-axis triangular distributions whose modes sit
    near the upper corner of the cell. The first point is duplicated once, the
    second twice, and so on, up to `max_equal_points_count - 1` repeated points
    (capped so the duplicates never exceed the total), so the returned list has
    exactly `total_count` entries.
    """
    edge_length = 2**root_node_level
    lows = list(root_node_lower_bound)
    # shave an epsilon off the upper bound so every sample stays strictly inside
    highs = [low + edge_length - 1.0e-10 for low in lows]
    peaks = [lows[axis] + edge_length - 2 ** (root_node_level - axis) for axis in range(3)]
    axis_parameters = list(zip(lows, highs, peaks))
    # largest k with k*(k-1)/2 + (k-1) <= total_count, clamped to [1, requested]
    duplicates_cap = math.floor(math.sqrt(2 * total_count + 9 / 4) - 0.5)
    equal_count = max(1, min(max_equal_points_count, duplicates_cap))
    unique_count = total_count - equal_count * (equal_count - 1) // 2
    sampled = [
        tuple(triangular(low, high, peak) for low, high, peak in axis_parameters)
        for _ in range(unique_count)
    ]
    repeats = [sampled[i] for i in range(equal_count - 1) for _ in range(i + 1)]
    return sampled + repeats


def xor_msb(first, second):
    """Return the power-of-two level of the most significant differing bit of two floats.

    This is the float version of "most significant bit of a XOR b" used for
    Morton-order comparison (assumes non-negative inputs — the generated point
    coordinates here). Returns INT_MIN as "minus infinity" for equal values.
    """
    if first == second:
        return INT_MIN
    # if one value is zero, the other value's highest set bit decides
    if first == 0:
        return math.frexp(second)[1] - 1
    if second == 0:
        return math.frexp(first)[1] - 1
    mantissa_a, exponent_a = math.frexp(first)
    mantissa_b, exponent_b = math.frexp(second)
    if exponent_a != exponent_b:
        # different binary magnitudes: the larger exponent wins outright
        return max(exponent_a, exponent_b) - 1
    # same exponent: scale both mantissas to 64-bit integers and XOR them
    shift = 64 - exponent_a
    integer_a = int(math.ldexp(mantissa_a, exponent_a + shift))
    integer_b = int(math.ldexp(mantissa_b, exponent_b + shift))
    return math.frexp(integer_a ^ integer_b)[1] - shift - 1


def morton_compare(left, right):
    """cmp-style comparator ordering 3D points along the Morton (Z-order) curve.

    The axis whose coordinates differ at the highest bit level dominates; the
    sign of the difference on that axis gives the ordering (negative, zero, or
    positive, suitable for functools.cmp_to_key).
    """
    dominant_axis = 0
    highest_level = INT_MIN
    for axis in range(3):
        level = xor_msb(left[axis], right[axis])
        if level > highest_level:
            dominant_axis, highest_level = axis, level
    return left[dominant_axis] - right[dominant_axis]


def mbs(first_point, second_point):
    """Highest differing-bit level between two points across all three axes."""
    return max(map(xor_msb, first_point, second_point))


def mbi(level, point):
    """Octant index (0-7) of `point` within its parent cell at the given level.

    Bit 2 comes from the x axis, bit 1 from y, bit 0 from z.
    """
    cell_size = 2**level
    bits = [math.floor(point[axis] / cell_size) % 2 for axis in range(3)]
    return (bits[0] << 2) | (bits[1] << 1) | bits[2]


@dataclass
class OctreeNode:
    """A cubic octree cell: geometry plus the run of points it owns.

    `children` is None for leaves; `points_start`/`point_count` describe a
    contiguous slice of the (Morton-sorted) point list.
    """

    level: int
    lower_bound: List[float]
    points_start: Union[int, None] = None
    point_count: Union[int, None] = None
    children: Union[list, None] = None

    @classmethod
    def from_parent(cls, parent, point_start=None, point_count=None):
        """Create the next child of `parent`; its octant index is len(parent.children)."""
        child_level = parent.level - 1
        half_edge = 2.0**child_level
        octant = len(parent.children)
        # octant bit 2 offsets the x axis, bit 1 the y axis, bit 0 the z axis
        corner = [
            parent.lower_bound[axis] + ((octant >> (2 - axis)) & 1) * half_edge
            for axis in range(3)
        ]
        return OctreeNode(level=child_level, lower_bound=corner, points_start=point_start, point_count=point_count)

    def contains(self, point):
        """Whether `point` lies in this cell (inclusive lower, exclusive upper bound)."""
        edge = 2.0**self.level
        for low, coordinate in zip(self.lower_bound, point):
            if not low <= coordinate < low + edge:
                return False
        return True

    def to_string(self, root_node_level=None):
        """Indented, one-node-per-line dump of the subtree's point counts."""
        top_level = root_node_level if root_node_level is not None else self.level
        own_line = "  |" * (top_level - self.level) + f"({self.point_count})\n"
        return own_line + "".join(child.to_string(top_level) for child in (self.children or []))


class OctreeBuilder:
    """
    implements the method by Kontkanen et al.

    Single pass over Morton-sorted points: a stack holds partially-built
    ancestor nodes, and the current node is split whenever more than
    `leaf_max` consecutive points fall inside it.
    """

    def __init__(self, root_node_lower_bound, root_node_level, points, leaf_max):
        """Build the whole tree immediately; the result is `self.root_node`."""
        self.leaf_max = leaf_max
        self.stack = []  # ancestors whose eight children are not yet complete
        self.current_node = self.root_node = OctreeNode(level=root_node_level, lower_bound=root_node_lower_bound)
        self.current_point_index = 0  # first point not yet assigned to a finished leaf
        self.max_stack_length = 0  # high-water mark, used for memory accounting
        self.build(points, leaf_max)

    def push_current_node_to_stack(self):
        """Turn the current node into an interior node and descend to its first child."""
        parent = self.current_node
        parent.children = []
        self.stack.append(parent)
        self.current_node = OctreeNode.from_parent(parent)
        if len(self.stack) > self.max_stack_length:
            self.max_stack_length = len(self.stack)

    def set_next_current_node(self):
        """Attach the finished current node to its parent and pick the next node to fill."""
        if not self.stack:
            return
        parent = self.stack[-1]
        parent.children.append(self.current_node)
        if len(parent.children) == 8:
            # all siblings are done, so the parent can be closed as well
            self.finalize_node_from_stack()
        else:
            self.current_node = OctreeNode.from_parent(parent)

    def finalize_current_node(self, point_count):
        """Close the current node as a leaf owning `point_count` consecutive points."""
        node = self.current_node
        node.points_start = self.current_point_index
        node.point_count = point_count
        self.current_point_index += point_count
        self.set_next_current_node()

    def finalize_node_from_stack(self):
        """Pop the top ancestor, pad it to eight children, and aggregate its totals."""
        node = self.stack.pop()
        self.current_node = node
        while len(node.children) < 8:
            # missing trailing children are empty leaves
            node.children.append(OctreeNode.from_parent(node, self.current_point_index, 0))
        node.points_start = node.children[0].points_start
        node.point_count = sum(child.point_count for child in node.children)
        self.set_next_current_node()

    def finish(self):
        """Close every ancestor still open on the stack."""
        while self.stack:
            self.finalize_node_from_stack()

    def build(self, points, leaf_max):
        """Consume the Morton-sorted `points` list and build the tree."""
        cursor = 0
        total = len(points)
        while cursor < total:
            if total - cursor > leaf_max and self.current_node.contains(points[cursor + leaf_max]):
                # more than leaf_max upcoming points lie in the current node: subdivide
                self.push_current_node_to_stack()
                continue
            # Morton order keeps points of a cell contiguous, so the first point
            # outside the current node can be located by bisection
            boundary = bisect_left(
                a=points,
                x=1,
                lo=cursor,
                hi=min(cursor + leaf_max, total),
                key=lambda candidate: int(not self.current_node.contains(candidate)),
            )
            self.finalize_current_node(boundary - cursor)
            cursor = boundary
        self.finish()

    def memory_used(self):
        """Estimated peak working memory in bytes (stacked nodes + one leaf's points)."""
        return self.max_stack_length * OCTREE_NODE_SIZE + self.leaf_max * POINT_SIZE


class LimitedMemoryOctreeBuilder(OctreeBuilder):
    """
    implements the method by Fischer et al.

    Memory-limited variant of OctreeBuilder: the Morton-sorted input is
    consumed in chunks of at most `chunk_size` points, and subdivision
    decisions that would otherwise need lookahead are deferred via per-level
    lists of already-known child point counts (`temporary_counts`).
    Single-letter comments (B, C, l, n, p_last) keep the paper's notation.
    """

    def __init__(self, root_node_lower_bound, root_node_level, points, leaf_max, chunk_size):
        # chunk_size bounds how many points must be resident at once; it is the
        # POINT_SIZE term in memory_used()
        self.chunk_size = chunk_size
        # high-water mark of len(temporary_counts), for memory accounting
        self.max_temporary_counts_length = 0
        super().__init__(root_node_lower_bound, root_node_level, points, leaf_max)  # runs build()

    def build(self, points, leaf_max):
        # Stream through `points` chunk by chunk; `self.current_node` tracks the
        # octree cell the most recent points fell into.
        def push_to_stack():
            # Subdivide the current node one level: push it onto the builder
            # stack, then immediately finalize the children whose point counts
            # were already recorded in temporary_counts for that depth.
            nonlocal temporary_counts
            nonlocal current_node_point_count
            nonlocal current_node_level
            self.push_current_node_to_stack()
            for point_count in temporary_counts[root_node_level - current_node_level]:
                self.finalize_current_node(point_count)
            # the replayed children no longer count towards the open node
            current_node_point_count -= sum(temporary_counts[root_node_level - current_node_level])
            current_node_level -= 1

        temporary_counts = []  # temporary_counts[d]: finished child counts at depth d+1 below the root
        root_node_level = self.current_node.level
        temporary_counts_level = root_node_level  # level down to which temporary_counts are recorded
        current_node_level = self.current_node.level  # B
        last_point = points[0]  # p_last
        current_node_point_count = 1  # n
        chunk_first_point_index = 1
        while chunk_first_point_index < len(points):
            next_chunk_first_point_index = min(chunk_first_point_index + self.chunk_size, len(points))
            temporary_node_level = current_node_level  # C
            while chunk_first_point_index < next_chunk_first_point_index:
                # index of the point that would overfill the node tracked so far,
                # given the counts already booked in temporary_counts
                desired_index = (
                    chunk_first_point_index
                    + leaf_max
                    - current_node_point_count
                    + sum(
                        sum(temporary_counts[level])
                        for level in range(root_node_level - current_node_level, root_node_level - temporary_node_level)
                    )
                )
                considered_index = min(desired_index, next_chunk_first_point_index - 1)  # l
                considered_point = points[considered_index]
                if last_point == points[next_chunk_first_point_index - 1] or (
                    next_chunk_first_point_index == len(points)
                    and considered_index < desired_index
                    and mbs(last_point, considered_point) < temporary_node_level
                ):
                    # the whole rest of the chunk stays with the current cell:
                    # absorb it, splitting whenever the leaf capacity overflows
                    current_node_point_count += next_chunk_first_point_index - chunk_first_point_index
                    chunk_first_point_index = next_chunk_first_point_index
                    while current_node_point_count > leaf_max:
                        push_to_stack()
                elif last_point == considered_point or mbs(last_point, considered_point) < temporary_node_level:
                    # the considered point is still inside the cell at level C:
                    # refine the candidate level by one
                    temporary_node_level -= 1
                    if temporary_counts_level > temporary_node_level:
                        temporary_counts_level -= 1
                        if len(temporary_counts) < root_node_level - temporary_counts_level:
                            temporary_counts.append([])
                            self.max_temporary_counts_length = max(
                                self.max_temporary_counts_length, len(temporary_counts)
                            )
                        # earlier siblings at the newly opened level are all empty
                        temporary_counts[root_node_level - temporary_counts_level - 1] = [0] * mbi(
                            temporary_node_level, last_point
                        )
                else:
                    # the considered point leaves the cell: bisect for the first
                    # point whose mbs reaches temporary_node_level
                    temporary_counts_level = temporary_node_level
                    cached_mbs = cache(partial(mbs, last_point))
                    index_of_first_outside = bisect_left(
                        a=points,
                        x=temporary_node_level,
                        lo=chunk_first_point_index,
                        hi=considered_index,
                        key=cached_mbs,
                    )
                    temporary_node_level = cached_mbs(points[index_of_first_outside])
                    # octants skipped between the closed cell and the new one are empty
                    empty_nodes_count = (
                        mbi(temporary_node_level, points[index_of_first_outside])
                        - mbi(temporary_node_level, last_point)
                        - 1
                    )
                    assert empty_nodes_count >= 0
                    last_point = points[index_of_first_outside]
                    current_node_point_count += index_of_first_outside + 1 - chunk_first_point_index
                    chunk_first_point_index = index_of_first_outside + 1
                    while current_node_point_count - 1 > leaf_max or (
                        current_node_level > temporary_node_level and current_node_point_count > leaf_max
                    ):
                        push_to_stack()
                    if current_node_level > temporary_node_level:
                        # the closed sibling's count is booked for later replay
                        assert len(temporary_counts[root_node_level - temporary_counts_level - 1]) < 7
                        temporary_counts[root_node_level - temporary_counts_level - 1].append(
                            current_node_point_count
                            - sum(
                                sum(temporary_counts[level])
                                for level in range(
                                    root_node_level - current_node_level, root_node_level - temporary_counts_level
                                )
                            )
                            - 1
                        )
                        temporary_counts[root_node_level - temporary_counts_level - 1] += [0] * empty_nodes_count
                    else:
                        # the split happens at or above the current level: the
                        # finished nodes can be finalized on the builder directly
                        self.finalize_current_node(point_count=current_node_point_count - 1)
                        while self.current_node.level < temporary_node_level:
                            self.finalize_node_from_stack()
                        for _ in range(empty_nodes_count):
                            self.finalize_current_node(point_count=0)
                        current_node_point_count = 1
                        current_node_level = temporary_node_level
                        temporary_counts_level = current_node_level
        self.finalize_current_node(point_count=current_node_point_count)
        self.finish()

    def memory_used(self):
        # stacked nodes + temporary count lists + one chunk of points, in bytes
        memory = (
            self.max_stack_length * OCTREE_NODE_SIZE
            + self.max_temporary_counts_length * TEMPORARY_COUNT_ELEMENT_SIZE
            + self.chunk_size * POINT_SIZE
        )
        return memory


def run_octree_construction(total_count, leaf_max_list, chunk_sizes):
    """Benchmark OctreeBuilder against LimitedMemoryOctreeBuilder.

    For each of several random point sets and each leaf_max, both builders run
    a few times; the best time per configuration and the memory estimates are
    accumulated. Returns {leaf_max: [(seconds, bytes), ...]} where entry 0 is
    the Kontkanen builder and entries 1.. correspond to `chunk_sizes`.
    """
    print(f"{total_count=}")
    root_node_lower_bound = [0.0, 0.0, 0.0]
    root_node_level = 3
    max_equal_points_count = 1
    number_of_point_sets = 5
    iterations_over_same_point_set = 3

    time_factor = 1e9 * number_of_point_sets  # ns -> s, averaged over point sets
    column_count = len(chunk_sizes) + 1  # Kontkanen column + one per chunk size
    statistics = {leaf_max: [(0, 0)] * column_count for leaf_max in leaf_max_list}
    for point_set_index in range(number_of_point_sets):
        print(point_set_index, end="")
        points = generate_points(
            root_node_lower_bound,
            root_node_level,
            total_count,
            max_equal_points_count=min(max_equal_points_count, *leaf_max_list),
        )
        print(", sorting ", end="")
        started = time.perf_counter_ns()
        if total_count > 10**7:
            # sort in place for huge sets to avoid holding a second copy
            points.sort(key=cmp_to_key(morton_compare))
        else:
            points = sorted(points, key=cmp_to_key(morton_compare))
        print(f"took {(time.perf_counter_ns() - started)/1e9} seconds.")
        for leaf_max in leaf_max_list:
            print("+", end="")
            best_times = [10**6 * time_factor] * column_count  # "not measured yet" sentinel
            memory_estimates = []
            for _ in range(iterations_over_same_point_set):
                print("|", end="")
                started = time.perf_counter_ns()
                reference_builder = OctreeBuilder(root_node_lower_bound, root_node_level, points, leaf_max)
                best_times[0] = min(best_times[0], time.perf_counter_ns() - started)
                memory_estimates = [reference_builder.memory_used()]
                for column, chunk_size in enumerate(chunk_sizes, start=1):
                    print(".", end="")
                    started = time.perf_counter_ns()
                    streaming_builder = LimitedMemoryOctreeBuilder(
                        root_node_lower_bound, root_node_level, points, leaf_max, chunk_size
                    )
                    best_times[column] = min(best_times[column], time.perf_counter_ns() - started)
                    memory_estimates.append(streaming_builder.memory_used())

                    # both construction methods must produce the same tree
                    assert reference_builder.root_node == streaming_builder.root_node
            accumulated = []
            for (total_time, total_memory), best_time, memory_estimate in zip(
                statistics[leaf_max], best_times, memory_estimates
            ):
                accumulated.append((total_time + best_time, total_memory + memory_estimate))
            statistics[leaf_max] = accumulated
        print("")
        del points  # free the point set before generating the next one

    return {
        leaf_max: [
            (total_time / time_factor, total_memory / number_of_point_sets)
            for (total_time, total_memory) in statistics[leaf_max]
        ]
        for leaf_max in leaf_max_list
    }


def print_statistics(statistics, chunk_size_list):
    """Print benchmark results per leaf_max, plus pgfplots ``\\addplot`` blocks.

    For each of "time" and "memory", the Kontkanen reference value is emitted
    as a horizontal line (two coordinates at the smallest and largest chunk
    size), followed by one (chunk_size, value) coordinate per chunk size for
    the limited-memory builder.
    """
    for leaf_max, statistics_for_leaf_max in statistics.items():
        print(f"{leaf_max=}:")
        print(
            f"Kontkanen: {statistics_for_leaf_max[0]}, "
            + ", ".join(f"{chunk_size=}: {t}" for chunk_size, t in zip(chunk_size_list, statistics_for_leaf_max[1:]))
        )
        for index, time_or_memory in enumerate(("time", "memory")):
            print(f"{time_or_memory}:")
            # horizontal reference line for the Kontkanen builder
            print(
                f"\\addplot[] coordinates {{({min(chunk_size_list)},{statistics_for_leaf_max[0][index]})"
                f"({max(chunk_size_list)},{statistics_for_leaf_max[0][index]})}};"
            )
            # Fix: the original emitted "coordinates {(" with a stray unmatched
            # "(" after the opening brace, producing invalid pgfplots syntax.
            print(
                "\\addplot[] coordinates {\n"
                + "".join(
                    f"({chunk_size},{statistics_for_chunk_size[index]})\n"
                    for chunk_size, statistics_for_chunk_size in zip(chunk_size_list, statistics_for_leaf_max[1:])
                )
                + "};"
            )


def octree_construction():
    """Run the benchmark for each configured problem size and print the results."""
    for total_count_exponent in [8]:
        total_count = 10**total_count_exponent
        # leaf capacities: powers of ten from one below total_count downwards
        leaf_max_list = [
            10**exponent for exponent in reversed(range(max(1, total_count_exponent - 4), total_count_exponent))
        ]
        smallest_exponent = max(0, total_count_exponent - 7)
        # chunk sizes: powers of ten interleaved with sqrt(10) multiples (half decades)
        chunk_size_list = sorted(
            [10**i for i in range(smallest_exponent, total_count_exponent + 1)]
            + [math.floor(math.sqrt(10) * 10**i) for i in range(smallest_exponent, total_count_exponent)]
        )
        statistics = run_octree_construction(total_count, leaf_max_list, chunk_size_list)
        print_statistics(statistics, chunk_size_list)


if __name__ == "__main__":
    # run the full benchmark suite when executed as a script
    octree_construction()
