"""
BSD 2-Clause License

Copyright (c) 2024, Centre Borelli
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from __future__ import annotations

import concurrent.futures
import multiprocessing
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Iterable, Mapping, Optional, Sequence

import numpy as np
from numpy.typing import NDArray
from scipy.sparse import csr_array
from scipy.spatial import Delaunay

# Edge type used throughout this module: a (start, end) pair of node indices.
_E = tuple[np.int32, np.int32]


def edges_to_start_end(edges: Iterable[_E]) -> tuple[list[np.int32], list[np.int32]]:
    """
    Split an iterable of edges into the list of start nodes and the list of end nodes.

    The input is traversed exactly once, so one-shot iterables (generators,
    iterators, ``zip`` objects, ...) are handled correctly, as well as
    re-iterable containers such as lists.

    Parameters
    ----------
    edges:
        Iterable of edges; element 0 of each edge is its start node and
        element 1 its end node.

    Returns
    -------
    (start, end):
        Two lists of equal length, in the iteration order of ``edges``.
    """
    start: list[np.int32] = []
    end: list[np.int32] = []
    # Single pass: iterating `edges` twice would leave `end` empty for generators.
    for edge in edges:
        start.append(edge[0])
        end.append(edge[1])
    return start, end


@dataclass(frozen=True)
class WrapAmbiguityAndGrad:
    """
    Immutable pair of values attached to one oriented edge (x, y):
    the gradient of the wrapped signal along the edge and the integer
    wrap ambiguity associated with that gradient.
    """

    # Integer n such that wrap(grad) = grad + 2 n h
    # (compute_wamb_grad_per_edge stores -1, 0 or 1 here).
    wrapping_amb: int

    # Raw difference of the wrapped signal along the edge (end minus start).
    grad_of_wrapped: float


def invert(wamb_grad: WrapAmbiguityAndGrad) -> WrapAmbiguityAndGrad:
    """
    Return the value for the reversed edge: both the wrap ambiguity and the
    gradient change sign when an edge (x, y) is read as (y, x).
    """
    return WrapAmbiguityAndGrad(
        wrapping_amb=-wamb_grad.wrapping_amb,
        grad_of_wrapped=-wamb_grad.grad_of_wrapped,
    )


def compute_wamb_grad_per_edge(
    sorted_half_edges: Iterable[_E], phase: NDArray[np.float64], h: float
) -> dict[_E, WrapAmbiguityAndGrad]:
    """
    Compute, for every edge, the gradient of the wrapped signal and its wrap ambiguity.

    For an edge (x, y) the gradient is grad = phase[y] - phase[x].
    The phase values are expected to lie either in [-h, h) or in (-h, h],
    so grad lies in (-2h, 2h) and wrap(grad) = grad + 2nh with n = -1, 0 or 1.
    Both grad and n are stored in a WrapAmbiguityAndGrad.

    For residue computation the gradients cancel out when integrating over
    cycles, so only n is needed; keeping n integral is faster and avoids
    rounding errors.

    Parameters
    ----------
    sorted_half_edges:
        Edges (x, y) used as keys of the returned dictionary.
    phase:
        Wrapped signal, indexed by node.
    h:
        Half-width of the wrapping interval.

    Returns
    -------
    Dictionary mapping each edge (x, y) to its WrapAmbiguityAndGrad.
    """
    start, end = edges_to_start_end(sorted_half_edges)

    grad = phase[end] - phase[start]

    # Ambiguity n: -1 where grad is above h, +1 where it is below -h, 0 elsewhere.
    ambiguity = np.zeros((len(start),), dtype=int)
    ambiguity[grad > h] = -1
    ambiguity[grad < -h] = 1

    # Building this dictionary takes most of the run time; to be removed if avoidable.
    values = map(WrapAmbiguityAndGrad, ambiguity, grad)
    return dict(zip(zip(start, end), values))


def get_edge_wamb_grad(
    wamb_grad_per_edge: dict[_E, WrapAmbiguityAndGrad], edge: _E
) -> WrapAmbiguityAndGrad:
    """
    Look up the value of an oriented edge, i.e. the value corresponding to end - start.

    The dictionary is keyed by (smaller, larger) node pairs. When the requested
    edge has start > end, the entry of the reversed key is fetched and inverted.

    Raises
    ------
    KeyError
        If the (ordered) edge is not a key of ``wamb_grad_per_edge``.
    """
    start, end = edge
    reversed_edge = start > end

    key = (end, start) if reversed_edge else (start, end)
    stored = wamb_grad_per_edge[key]

    return invert(stored) if reversed_edge else stored


def get_sorted_half_edges(Adj: csr_array) -> list[_E]:
    """
    List the edges (s, e) with s < e among the nonzero entries of ``Adj``.

    Only one orientation of each edge is kept, so a symmetric adjacency
    matrix yields every undirected edge exactly once. Diagonal entries
    (s == e) are dropped.
    """
    rows, cols = Adj.nonzero()
    upper = rows < cols
    return list(zip(rows[upper], cols[upper]))


class ConstraintRightSide(ABC):
    """
    Interface providing, edge by edge, the right side of the cycle constraint.
    """

    @abstractmethod
    def get_at_edge(self, edge: _E) -> float:
        r"""
        Get the right side of the constraint for edge (x, y),
        such that
        on cycles: \sum \delta_xy = \sum constraint_right_side_xy.

        The returned value is either an int or a float, depending on the
        implementation.
        """
        pass


@dataclass(frozen=True)
class WambConstraintRightSide(ConstraintRightSide):
    """
    Constraint right side given by the integer wrap ambiguity of each edge.
    """

    # Per-edge gradient / wrap ambiguity, keyed by edge (see get_edge_wamb_grad).
    wamb_grad_per_edge: dict[_E, WrapAmbiguityAndGrad]

    def get_at_edge(self, edge: _E) -> int:
        """
        Return the wrap ambiguity of ``edge``; the sign is flipped when the
        edge is requested with start > end (handled by get_edge_wamb_grad).
        """
        wamb_grad = get_edge_wamb_grad(self.wamb_grad_per_edge, edge)
        return wamb_grad.wrapping_amb


class DeltaProvider(ABC):
    """
    Interface providing, edge by edge, the correction delta of the cycle constraint.
    """

    @abstractmethod
    def get_delta(self, e: _E) -> float:
        r"""
        Get delta of edge e = (x, y),
        such that
        on cycles: \sum \delta_xy = \sum constraint_right_side_xy.

        The returned value is either an int or a float, depending on the
        implementation.
        """
        pass


class UnwGradProvider(ABC):
    """
    Interface providing, edge by edge, the corrected (unwrapped) gradient.
    """

    @abstractmethod
    def get_at_edge(self, edge: _E) -> float:
        """
        Return the unwrapped gradient of edge (x, y) once the correction is applied.

        How the value is computed depends on the equation defining the
        constraint (its right side and its deltas), so each implementation
        of this method may differ.
        """
        ...


@dataclass(frozen=True)
class WambUnwGradProvider(UnwGradProvider):
    r"""
    Unwrapped-gradient provider for the wrap-ambiguity formulation.

    Equation described by:
        f_y - f_x + 2 h \delta_xy = f'_y - f'_x + 2 h wamb_xy
    """

    # Per-edge gradient of the wrapped signal and wrap ambiguity, keyed by edge.
    wamb_grad_per_edge: dict[_E, WrapAmbiguityAndGrad]
    # Source of the correction delta_xy of each edge.
    delta_provider: DeltaProvider
    # Half-width of the wrapping interval; ambiguities are multiples of 2 h.
    h: float

    def get_at_edge(self, edge: _E) -> float:
        """
        Get the unwrapped gradient for edge (x, y) after correction, i.e.
        grad_of_wrapped + 2 h (wamb_xy - delta_xy).
        """
        delta_xy = self.delta_provider.get_delta(edge)
        wamb_grad_xy = get_edge_wamb_grad(self.wamb_grad_per_edge, edge)

        correction = (wamb_grad_xy.wrapping_amb - delta_xy) * 2 * self.h
        return wamb_grad_xy.grad_of_wrapped + correction


def integrate(
    nodes_predecessors_order: tuple[NDArray[np.int32], NDArray[np.int32]],
    unw_grad_provider: UnwGradProvider,
) -> NDArray[np.float64]:
    """
    Rebuild a function on the nodes by accumulating unwrapped gradients.

    Parameters
    ----------
    nodes_predecessors_order:
        Pair (nodes, predecessor). ``nodes`` gives the order in which nodes
        are processed; ``predecessor`` is indexed by node and gives the node
        from which each one is reached. The entry of the first node is never
        read. NOTE(review): each predecessor is presumably expected to appear
        earlier in ``nodes`` (tree traversal order) - this is not checked.
    unw_grad_provider:
        Gives the unwrapped gradient of every (predecessor, node) edge.

    Returns
    -------
    Array of length ``len(nodes)`` indexed by node; the value at the root
    ``nodes[0]`` is 0.
    """
    order, predecessor = nodes_predecessors_order

    # Every value starts at 0; the root nodes[0] is never overwritten.
    fct = np.zeros(len(order), dtype=np.float64)

    # Skip the root and propagate along each (predecessor, node) edge.
    for y in order[1:]:
        x = predecessor[y]
        fct[y] = fct[x] + unw_grad_provider.get_at_edge((x, y))

    return fct


def delaunay_to_AdjMat(D: Delaunay) -> csr_array:
    """
    Build the boolean adjacency matrix of the vertices of a Delaunay triangulation.

    ``D.vertex_neighbor_vertices`` is an (indptr, indices) pair, which is used
    directly as the index structure of an (n, n) CSR array, n being the number
    of points of the triangulation. Every stored entry is True.
    """
    num_points = len(D.points)
    indptr, indices = D.vertex_neighbor_vertices
    ones = np.ones(indices.shape, dtype=bool)
    return csr_array((ones, indices, indptr), shape=(num_points, num_points))


def interval_phase_assertion(
    phase: NDArray[np.float64], h: float, prefix_msg: str = ""
) -> None:
    """
    Assert that all non-NaN values of ``phase`` lie in [-h, h) or in (-h, h].

    NaN entries are ignored (the bounds are computed with nanmin / nanmax).

    Raises
    ------
    AssertionError
        If the values fit in neither interval; the message starts with
        ``prefix_msg`` and reports h together with the observed bounds.
    """
    lowest = np.nanmin(phase)
    highest = np.nanmax(phase)

    fits_left_closed = lowest >= -h and highest < h
    fits_right_closed = lowest > -h and highest <= h

    suffix_msg = (
        "Input must be in either the [-h,h) or the (-h, h] interval,\n"
        f"    with h={h}, but min_phase={lowest} and max_phase={highest}"
    )
    assert fits_left_closed or fits_right_closed, f"{prefix_msg}: {suffix_msg}"


def get_adj_from_sorted_half_edges(
    sorted_half_edges: Iterable[_E], num_nodes: Optional[int] = None, *, symmetric=True
) -> csr_array:
    """
    Build a boolean adjacency matrix from a collection of edges.

    Parameters
    ----------
    sorted_half_edges:
        Edges (start, end) to store.
    num_nodes:
        Side of the square matrix. When None, no shape is passed to
        ``csr_array``, which then infers it from the indices.
    symmetric:
        When True (default) each edge is stored in both directions,
        (start, end) and (end, start); otherwise only as given.

    Returns
    -------
    Boolean ``csr_array`` with a True entry for every stored edge.
    """
    start, end = edges_to_start_end(sorted_half_edges)

    # start / end are lists, so "+" concatenates: edges followed by their mirrors.
    row, col = (start + end, end + start) if symmetric else (start, end)

    ones = np.ones(len(row), dtype=bool)
    shape = (num_nodes, num_nodes) if num_nodes is not None else None
    return csr_array((ones, (row, col)), shape=shape, dtype=bool)


def get_elem_from_sorted_grad_dict(grad_per_edge: Mapping[_E, float], k: _E) -> float:
    """
    Return the gradient of the oriented edge ``k`` from a mapping keyed by
    (smaller, larger) node pairs.

    When ``k`` is stored as given (k[0] < k[1]) the value is returned
    unchanged; otherwise the value of the swapped key is returned with its
    sign flipped.
    """
    first, second = k[0], k[1]
    if not first < second:
        return -grad_per_edge[(second, first)]
    return grad_per_edge[k]


def sort_key(k: _E) -> _E:
    """
    Return the edge with its nodes in increasing order.

    The edge is returned unchanged when k[0] < k[1]; otherwise a new tuple
    with the two nodes swapped is returned.
    """
    return k if k[0] < k[1] else (k[1], k[0])


def get_elem_from_sorted_dict(dict_per_edge_sorted: Mapping[_E, float], k: _E) -> float:
    """
    Return the value stored for edge ``k``, whatever its orientation.

    The edge is first put in (smaller, larger) order with ``sort_key``;
    unlike gradients, the value is returned without any sign change.
    """
    ordered_edge = sort_key(k)
    return dict_per_edge_sorted[ordered_edge]


@dataclass(frozen=True)
class FloatConstraintRightSide(ConstraintRightSide):
    r"""
    Constraint right side given by an observed, float-valued gradient per edge.

    Equation described by:
        f_y - f_x + \delta_xy = f'_xy
    """

    # f'_xy: observed gradient of each edge, keyed by (smaller, larger) node pairs.
    observed_grad_per_edge: Mapping[_E, float]

    def get_at_edge(self, edge: _E) -> float:
        """
        Return the observed gradient f'_xy of ``edge``; the sign is flipped
        when the edge is requested in the orientation opposite to its key.
        """
        return get_elem_from_sorted_grad_dict(self.observed_grad_per_edge, edge)


@dataclass(frozen=True)
class FloatUnwGradProvider(UnwGradProvider):
    r"""
    Unwrapped-gradient provider for the float-valued formulation.

    Equation described by:
        f_y - f_x + \delta_xy = f'_xy
    """

    # f'_xy: observed gradient of each edge, keyed by (smaller, larger) node pairs.
    # Only read through indexing, so any Mapping is accepted.
    observed_grad_per_edge: Mapping[_E, float]
    # Source of the correction delta_xy of each edge.
    delta_provider: DeltaProvider

    def get_at_edge(self, edge: _E) -> float:
        """
        Get the unwrapped gradient f_y - f_x for edge (x, y) after correction,
        i.e. f'_xy - delta_xy.
        """
        observed_grad = get_elem_from_sorted_grad_dict(
            self.observed_grad_per_edge, edge
        )
        return observed_grad - self.delta_provider.get_delta(edge)


class ResultProvider(ABC):
    """
    Interface of an object computing one result array from one input array.
    """

    # Shape of a single result; get_result_on_stack uses it to allocate the
    # stacked output when running in parallel.
    result_shape: tuple[int, ...]

    @abstractmethod
    def get_result(self, in_array: NDArray[np.float64]) -> NDArray[np.float64]:
        """Compute and return the result associated with ``in_array``."""
        ...


def get_result_on_stack(
    result_provider: ResultProvider,
    arrays_list: list[NDArray[np.float64]],
    ncpu: int = 1,
):
    """
    Apply ``result_provider.get_result`` to every array of ``arrays_list``.

    Parameters
    ----------
    result_provider:
        Object computing one result per input array.
    arrays_list:
        Input arrays; results are returned in the same order.
    ncpu:
        With 1 (default) the arrays are processed sequentially in the current
        process and any exception propagates to the caller. With any other
        value a process pool of ``ncpu`` workers (``spawn`` start method) is
        used.

    Returns
    -------
    List with one result per input array. In the parallel case the results
    are rows of a float64 array of shape
    ``(len(arrays_list), *result_provider.result_shape)``; when a task raises,
    a message is printed and the corresponding row is left filled with zeros.
    """
    if ncpu == 1:
        return [result_provider.get_result(arr) for arr in arrays_list]

    spawn_context = multiprocessing.get_context("spawn")
    with concurrent.futures.ProcessPoolExecutor(
        max_workers=ncpu, mp_context=spawn_context
    ) as pool:
        # Remember the position of every submitted task: completion order is arbitrary.
        index_of_future = {}
        for index, arr in enumerate(arrays_list):
            index_of_future[pool.submit(result_provider.get_result, arr)] = index

        stacked = np.zeros(
            (len(arrays_list), *result_provider.result_shape), dtype=np.float64
        )

        for done in concurrent.futures.as_completed(index_of_future):
            index = index_of_future[done]
            try:
                data = done.result()
            except Exception as exc:
                # Best effort: report the failure and keep the zero-filled row.
                print(f"Date {index} generated an exception: {exc}")
            else:
                stacked[index] = data

    return list(stacked)


def array_and_edges_to_dict_per_edge(
    arr: NDArray[Any], edges: Sequence[_E]
) -> dict[_E, Any]:
    """
    Pair every edge with the array element at the same position.

    Raises
    ------
    AssertionError
        If ``arr`` and ``edges`` do not have the same length.
    """
    num_values, num_edges = len(arr), len(edges)
    assert (
        num_values == num_edges
    ), "Must have the same number of edges and array elements"
    return {edge: value for edge, value in zip(edges, arr)}


@dataclass(frozen=True)
class DictPerEdgeProvider:
    """
    Turn per-edge arrays into dictionaries keyed by edge.

    Instances built with ``from_edges_array`` hold the edges with their nodes
    in increasing order, together with a mask of the edges that had to be
    flipped to get there.
    """

    # Edges as (start, end) tuples, in the order of the per-edge arrays.
    sorted_half_edges: list[_E]
    # True where the original edge was given with start > end.
    flipped_mask: NDArray[np.bool_]

    def get_grad_dict(self, grad: NDArray[np.float64]) -> dict[_E, float]:
        """
        Map every edge to its gradient, expressed in the stored orientation.

        ``grad`` is not modified: a copy is made, and the sign is changed on
        the entries whose edge was flipped.
        """
        flipped = self.flipped_mask
        oriented_grad = grad.copy()
        oriented_grad[flipped] = -grad[flipped]
        return array_and_edges_to_dict_per_edge(oriented_grad, self.sorted_half_edges)

    def get_weight_dict(
        self, weight: Optional[NDArray[np.float64]]
    ) -> Optional[dict[_E, float]]:
        """
        Map every edge to its weight, or return None when no weight is given.

        Weights do not depend on the edge orientation, so no sign change is applied.
        """
        if weight is not None:
            return array_and_edges_to_dict_per_edge(weight, self.sorted_half_edges)
        return None

    @staticmethod
    def from_edges_array(edges: NDArray[np.int32]) -> DictPerEdgeProvider:
        """
        Build a provider from an array of edges with one (start, end) row per edge.

        ``edges`` is left untouched. Rows with start > end are swapped in a
        copy and flagged in ``flipped_mask``.

        Raises
        ------
        AssertionError
            If the same edge appears more than once (in either orientation).
        """
        flipped_mask = edges[:, 0] > edges[:, 1]

        # Swap start and end on the flipped rows, working on a copy.
        sorted_edges = edges.copy()
        sorted_edges[flipped_mask] = edges[flipped_mask][:, [1, 0]]

        # Repeated edges are rejected for now (assertion error).
        # Another strategy would be to store the indices of repeated entries
        # given by np.unique and check that the corresponding gradients are equal.
        # An alternative would be to take one of the entries without checking.
        sorted_edges_list = list(zip(sorted_edges[:, 0], sorted_edges[:, 1]))
        num_distinct = len(set(sorted_edges_list))

        msg = "Repeated entries for the same edge are not allowed."
        assert num_distinct == len(sorted_edges_list), msg

        return DictPerEdgeProvider(sorted_edges_list, flipped_mask)
