"""
BSD 2-Clause License

Copyright (c) 2024, Centre Borelli
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from __future__ import annotations

import time
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass
from typing import Iterable, Optional

import numpy as np
import pulp
from delaunay_unwrap.utils import (
    _E,
    ConstraintRightSide,
    DeltaProvider,
    DictPerEdgeProvider,
    FloatConstraintRightSide,
    FloatUnwGradProvider,
    ResultProvider,
    WambConstraintRightSide,
    WambUnwGradProvider,
    compute_wamb_grad_per_edge,
    delaunay_to_AdjMat,
    get_adj_from_sorted_half_edges,
    get_elem_from_sorted_dict,
    get_result_on_stack,
    get_sorted_half_edges,
    integrate,
    interval_phase_assertion,
)
from numpy.typing import NDArray
from scipy.sparse import csr_array, dia_array
from scipy.sparse.csgraph import breadth_first_order
from scipy.spatial import Delaunay

# %%
# This module implements algorithm 1 (unwrap_delaunay_linprog_ts) together with its sub-algorithms 2-4.


# Provides neighboring nodes of i
class NeighborProvider(ABC):
    """Abstract source of graph adjacency: maps a node index to its neighbors."""

    @abstractmethod
    def get_neighbors(self, i: np.int32) -> Iterable[np.int32]:
        """Return the nodes adjacent to node ``i``."""
        ...


# Neighbor provider in CSR format
@dataclass(frozen=True)
class CSRAdjNeighborProvider(NeighborProvider):
    """Neighbor lookup backed by the ``indptr``/``indices`` arrays of a CSR matrix.

    The neighbors of row ``i`` are the column indices stored between
    ``indptr[i]`` and ``indptr[i + 1]``. The result is a slice of ``indices``.
    """

    indptr: NDArray[np.int32]
    indices: NDArray[np.int32]

    def get_neighbors(self, i: np.int32) -> NDArray[np.int32]:
        row_start, row_stop = self.indptr[i], self.indptr[i + 1]
        return self.indices[row_start:row_stop]


# networkx is an optional dependency: NxNeighborProvider is only defined when
# networkx can be imported; otherwise the name is simply absent from this module.
try:
    from networkx.classes.graph import Graph
except ImportError:
    pass
else:

    @dataclass(frozen=True)
    class NxNeighborProvider(NeighborProvider):
        """Neighbor lookup backed by a networkx ``Graph``."""

        G: Graph

        def get_neighbors(self, i: np.int32) -> list[np.int32]:
            # Materialize the adjacency of node i (networkx: iterating G[i] yields its neighbors)
            neigh = list(self.G[i])
            return neigh


# Removes self-loops, i.e. the nonzero diagonal entries of the adjacency matrix
def custom_remove_diag(x: csr_array) -> csr_array:
    """Return ``x`` with its main diagonal zeroed out.

    The diagonal is removed by subtracting a sparse DIA matrix that holds
    exactly the diagonal of ``x``.
    """
    main_diagonal = x.diagonal().reshape(1, -1)
    diagonal_part = dia_array((main_diagonal, [0]), shape=x.shape)
    return x - diagonal_part


# This corresponds to the computation of AdjRed of algorithm 1.
def get_AdjRed(Adj: csr_array, k: int, *, remove_self_edges=False) -> csr_array:
    """Build the redundant adjacency matrix AdjRed of algorithm 1.

    Starting from ``Adj``, the update ``AdjRed <- Adj + AdjRed @ Adj`` is
    applied ``k`` times, giving ``Adj + Adj^2 + ... + Adj^(k+1)``. Node pairs
    joined by a walk of 1 to k+1 edges therefore become nonzero. With
    ``remove_self_edges`` the diagonal (closed walks) is dropped.
    """
    result = Adj.copy()
    for _ in range(k):
        result = Adj + result.dot(Adj)
    return custom_remove_diag(result) if remove_self_edges else result


# This corresponds to growing a cluster in algorithm 3.
# If an intersection occurs between the two clusters, its location is returned
def breadth_first_grow(
    tovisitneighbors: list[np.int32],
    neighbor_provider: NeighborProvider,
    visited: list[np.int32],
    pred: dict[np.int32, np.int32],
    visited_other: list[np.int32],
) -> tuple[list[np.int32], Optional[np.int32]]:
    """Expand one breadth-first level of a cluster (a step of algorithm 3).

    Every not-yet-visited neighbor of the current frontier is appended to
    ``visited`` and to the next frontier, and its predecessor is recorded in
    ``pred`` (both mutated in place). Expansion stops as soon as a newly
    reached node already belongs to ``visited_other``.

    Returns:
        ``(next_frontier, meeting_node)``. ``meeting_node`` is ``None`` when
        the two clusters did not meet during this step. On a meeting, the
        returned frontier is partial.
    """
    next_frontier: list[np.int32] = []
    for parent in tovisitneighbors:
        for child in neighbor_provider.get_neighbors(parent):
            if child in visited:
                continue
            visited.append(child)
            next_frontier.append(child)
            pred[child] = parent
            if child in visited_other:
                # The two clusters touch: report where
                return next_frontier, child
    return next_frontier, None


# Returns the unique path from start to root in a tree
def path_to_root(
    pred: dict[np.int32, np.int32], start: np.int32, root: np.int32
) -> list[np.int32]:
    """Follow ``pred`` links from ``start`` up to ``root``.

    Returns the visited nodes, ``start`` first and ``root`` last.
    """
    path = [start]
    while path[-1] != root:
        path.append(pred[path[-1]])
    return path


# This is an implementation of algorithm 3.
def GetShortestPath(
    neighbor_provider: NeighborProvider, E0: np.int32, E1: np.int32
) -> list[np.int32]:
    """Find a shortest path between ``E0`` and ``E1`` (algorithm 3).

    Two breadth-first searches grow alternately, one level at a time, from
    ``E0`` and from ``E1``. They stop when a newly reached node belongs to the
    other search, and the two search trees are joined at that meeting node.

    Returns:
        Node list starting at ``E0`` and ending at ``E1`` (``[E0]`` when both
        endpoints coincide).

    Raises:
        ValueError: if ``E1`` cannot be reached from ``E0``.
    """
    if E0 == E1:
        # Without this the searches could never "meet" on a new node
        return [E0]

    pred0: dict[np.int32, np.int32] = {}
    pred1: dict[np.int32, np.int32] = {}
    tovisitneighbors0 = [E0]
    tovisitneighbors1 = [E1]
    visited0 = [E0]
    visited1 = [E1]
    where: Optional[np.int32] = None
    while where is None:  # do two breadthfirsts until intersection occurs
        # An empty frontier means one side has explored its whole connected
        # component without meeting the other, so no path exists. Without this
        # check the loop would spin forever.
        if not tovisitneighbors0 or not tovisitneighbors1:
            raise ValueError(f"No path between nodes {E0} and {E1}")

        tovisitneighbors0, where = breadth_first_grow(
            tovisitneighbors0, neighbor_provider, visited0, pred0, visited1
        )

        if where is None:
            tovisitneighbors1, where = breadth_first_grow(
                tovisitneighbors1, neighbor_provider, visited1, pred1, visited0
            )

    # Now intersection has occurred, reconstruct path
    path = path_to_root(pred0, where, E0)
    path = path[:0:-1]  # go from E0 to intersection (not included)
    path += path_to_root(pred1, where, E1)

    return path


# Returns the cycle basis of the Delaunay simplices
def simplices_to_basis(simplices: NDArray[np.int32]) -> list[list[np.int32]]:
    """Turn every simplex (row of vertex indices) into a cycle given as a node list."""
    return list(map(list, simplices))


# This is an implementation of algorithm 2.
def GetSmallBasis(
    simplices: NDArray[np.int32],
    Adj: csr_array,
    AdjRed: csr_array,
) -> list[list[np.int32]]:
    """Build the small cycle basis of algorithm 2.

    The basis is made of the Delaunay triangles plus, for every redundant edge
    (nonzero in ``AdjRed`` but not in ``Adj``), one cycle formed by that edge
    and a shortest path between its endpoints in the original Delaunay graph.

    Args:
        simplices: Delaunay triangles as rows of node indices.
        Adj: Delaunay adjacency matrix (assumed symmetric).
        AdjRed: redundant adjacency matrix, expected to contain Adj's edges.

    Returns:
        List of cycles, each a list of node indices.
    """
    # Delaunay triangles
    basis = simplices_to_basis(simplices)

    # Additional redundant edges. Compare sparsity patterns, not values: with
    # numeric entries, get_AdjRed sums walk counts (Adj + Adj^2 + ...), so
    # "AdjRed - Adj" would also be nonzero on the original Delaunay edges and
    # add a trivial back-and-forth cycle (an always-zero constraint) for each.
    in_red = (AdjRed != 0).astype(np.int8)
    in_adj = (Adj != 0).astype(np.int8)
    start, end = ((in_red - in_adj) > 0).nonzero()

    # We use Adj (orig Delaunay) edges for shortest path search
    neighbor_provider = CSRAdjNeighborProvider(Adj.indptr, Adj.indices)

    for e0, e1 in zip(start, end):
        # Do it once per edge, do not repeat for reverse edge
        if e0 < e1:
            # Shortest Delaunay path between e0 and e1, closed by the redundant edge
            cycle = GetShortestPath(neighbor_provider, e0, e1)
            basis.append(cycle)

    return basis


def visit_predecessor_and_intersect(
    curr_node: np.int32,
    other_leaf: np.int32,
    visited_curr_branch: list[np.int32],
    visited_other_branch: list[np.int32],
    used: dict[np.int32, set[np.int32]],
) -> Optional[list[np.int32]]:
    """Extend one branch by ``curr_node`` and test whether a cycle closes.

    ``curr_node`` is appended to ``visited_curr_branch`` (in place). A cycle is
    returned in either of these cases:

    * an edge between ``curr_node`` and ``other_leaf`` was already recorded in
      ``used``: the cycle is shortcut through it (so the result is not strictly
      a fundamental cycle);
    * ``curr_node`` lies on the other branch: both branches are joined there.

    Returns ``None`` when neither holds.
    """
    visited_curr_branch.append(curr_node)

    shortcut = curr_node in used[other_leaf] or other_leaf in used[curr_node]
    if shortcut:
        # path curr_leaf ... curr_node, then jump directly to other_leaf
        return visited_curr_branch + [other_leaf]

    if curr_node in visited_other_branch:
        # other branch from its leaf up to (excluding) curr_node, walked backwards
        cut = visited_other_branch.index(curr_node)
        return visited_curr_branch + visited_other_branch[:cut][::-1]

    return None


def find_cycle(
    z: np.int32,
    nbr: np.int32,
    root: np.int32,
    pred: dict[np.int32, np.int32],
    used: dict[np.int32, set[np.int32]],
) -> list[np.int32]:
    """Recover the cycle closed by the non-tree edge ``z``-``nbr``.

    Walks up the spanning tree (``pred``) alternately from ``z`` and from
    ``nbr``, one step each, until ``visit_predecessor_and_intersect`` reports
    a cycle. A branch that has reached ``root`` stops moving.
    """
    branch_z, branch_nbr = [z], [nbr]
    cursor_z, cursor_nbr = z, nbr

    while True:
        if cursor_z != root:
            cursor_z = pred[cursor_z]
            found = visit_predecessor_and_intersect(
                cursor_z, nbr, branch_z, branch_nbr, used
            )
            if found is not None:
                return found

        if cursor_nbr != root:
            cursor_nbr = pred[cursor_nbr]
            found = visit_predecessor_and_intersect(
                cursor_nbr, z, branch_nbr, branch_z, used
            )
            if found is not None:
                return found


# This is an implementation of algorithm 4.
def GetFundamentalCycleBasis(
    nodes: Iterable[np.int32],
    neighbor_provider: NeighborProvider,
    root: Optional[np.int32] = None,
) -> tuple[list[list[np.int32]], dict[np.int32, np.int32]]:
    """Compute a cycle basis from breadth-first spanning trees (algorithm 4).

    Every connected component is explored breadth-first from a root. Each
    non-tree edge closes one cycle, which is recovered by ``find_cycle``.

    Warning: This is actually not a fundamental cycle basis, because we
    simplify some cycles in find_cycle.

    Args:
        nodes: nodes of the graph. Each component not yet covered gets its own
            spanning tree.
        neighbor_provider: adjacency of the graph.
        root: optional root of the first spanning tree. Later components are
            rooted at the last node remaining in ``nodes``.

    Returns:
        ``(cycles, pred)``. ``cycles`` holds node lists (a self-loop gives a
        one-node cycle). ``pred`` is the predecessor map of the LAST explored
        spanning tree, with its root mapping to itself. Both are empty when
        ``nodes`` is empty.
    """
    # We copy because we are going to pop nodes from here (dict used as an
    # insertion-ordered set)
    gnodes = dict.fromkeys(np.array(nodes, dtype=np.int32))
    cycles = []
    # Initialised up front so an empty ``nodes`` returns an empty map instead of
    # raising UnboundLocalError at the final return.
    pred: dict[np.int32, np.int32] = {}

    while gnodes:  # loop over connected components
        if root is None:
            root = gnodes.popitem()[0]
        else:
            # It seems to slow down heavily if we don't cast
            root = np.int32(root)

        queue = deque([root])
        pred = {root: root}
        # used[n]: nodes m whose edge m-n is already accounted for (tree edge
        # from the parent, or a non-tree edge whose cycle was recorded), so
        # every edge yields at most one cycle.
        used: dict[np.int32, set[np.int32]] = {root: set()}
        while queue:  # walk the spanning tree finding cycles
            z = queue.popleft()  # use first-in to get breadthfirst graph
            zused = used[z]
            neighbors = neighbor_provider.get_neighbors(z)
            for nbr in neighbors:
                if nbr not in used:  # new node
                    pred[nbr] = z
                    queue.append(nbr)
                    used[nbr] = {z}
                elif nbr == z:  # self loops
                    cycles.append([z])
                elif nbr not in zused:  # found a cycle
                    cycle = find_cycle(z, nbr, root, pred, used)
                    cycles.append(cycle)
                    used[nbr].add(z)

        # Everything reached from this root belongs to the explored component
        for node in pred:
            gnodes.pop(node, None)
        root = None

    return cycles, pred


@dataclass(frozen=True)
class GraphStructure:
    """Immutable, phase-independent description of the unwrapping graph.

    Built once and shared by every array of a stack.
    """

    # Half-edges (node pairs); each one gets a pair of LP correction variables.
    sorted_half_edges: Iterable[_E]
    # Cycles as lists of node indices; one LP constraint is written per cycle.
    basis: list[list[np.int32]]
    # (node order, predecessors) from scipy's breadth_first_order, used to integrate.
    nodes_predecessors_order: tuple[NDArray[np.int32], NDArray[np.int32]]


# Returns the phase-independent properties of the graph: the sorted half-edges of the redundant adjacency matrix, the cycle basis and the breadth-first node order with predecessors
def get_graph_structure(
    pts: NDArray[np.int64],
    k: int,
    useSmallBasis: bool = True,
    integration_ref_id: np.int32 = np.int32(0),
) -> GraphStructure:
    """Precompute everything about the point graph that does not depend on the phase.

    Args:
        pts: point coordinates, triangulated with ``scipy.spatial.Delaunay``.
        k: redundancy order forwarded to ``get_AdjRed`` and ``get_cycle_basis``.
        useSmallBasis: use the small basis (algorithm 2) instead of
            ``GetFundamentalCycleBasis`` on the redundant graph.
        integration_ref_id: start node of the breadth-first integration order.

    Returns:
        GraphStructure with the half-edges of AdjRed, the cycle basis and the
        breadth-first (order, predecessors) pair over the Delaunay graph.
    """
    # "QJ" is qhull's joggled-input option (always produces simplicial output).
    # Apparently all triangles are counterclockwise, no need to account for orientation.
    triangulation = Delaunay(pts, qhull_options="QJ")

    # Cheap: the CSR structure is already stored in the triangulation
    delaunay_adj = delaunay_to_AdjMat(triangulation)
    redundant_adj = get_AdjRed(delaunay_adj, k, remove_self_edges=True)

    half_edges = get_sorted_half_edges(redundant_adj)
    bfs_order = breadth_first_order(delaunay_adj, integration_ref_id, directed=False)
    cycles = get_cycle_basis(
        triangulation.simplices,
        delaunay_adj,
        redundant_adj,
        k,
        useSmallBasis=useSmallBasis,
    )
    return GraphStructure(half_edges, cycles, bfs_order)


# Returns a string corresponding to an edge (used for LP variable names)
def edge_to_str(edge: _E) -> str:
    """Name an edge ``(a, b)`` as ``"a_b"`` (used for LP variable names)."""
    head, tail = edge[0], edge[1]
    return "{}_{}".format(head, tail)


# Returns edge ji given edge ij (reverse edge)
def reverse_edge(edge: _E) -> _E:
    """Return the opposite half-edge: ``(a, b)`` becomes ``(b, a)``."""
    first, second = edge[0], edge[1]
    return second, first


# Returns a cycle basis according to user's parameter choices
def get_cycle_basis(
    simplices: NDArray[np.int32],
    Adj: csr_array,
    AdjRed: csr_array,
    k: int,
    useSmallBasis: bool = True,
) -> list[list[np.int32]]:
    """Pick the cycle basis requested by the user.

    * ``useSmallBasis`` and ``k == 0``: the Delaunay triangles alone (no
      redundant edges were added).
    * ``useSmallBasis`` and ``k > 0``: the small basis of algorithm 2.
    * otherwise: ``GetFundamentalCycleBasis`` over the whole redundant graph.
    """
    if not useSmallBasis:
        all_nodes = np.arange(0, AdjRed.shape[0], dtype=np.int32)
        provider = CSRAdjNeighborProvider(AdjRed.indptr, AdjRed.indices)
        cycles, _ = GetFundamentalCycleBasis(all_nodes, provider)
        return cycles
    if k == 0:
        return simplices_to_basis(simplices)
    return GetSmallBasis(simplices, Adj, AdjRed)


# Provides the optimal deltas to correct the wrapped gradient into a true gradient
def get_optimum_deltas(
    sorted_half_edges: Iterable[_E],
    basis: list[list[np.int32]],
    constraint_right_side: ConstraintRightSide,
    capacity: Optional[float] = None,
    confidence_weight_per_edge: Optional[dict[_E, float]] = None,
) -> dict[_E, pulp.LpVariable]:
    """Solve the edge-correction LP and return its solved variables.

    For every half-edge ``(a, b)`` two non-negative variables are created, one
    keyed ``(a, b)`` and one keyed ``(b, a)``. The correction of ``a -> b`` is
    ``v(a, b) - v(b, a)``. The objective is the (optionally weighted) sum of
    all variables. For every cycle of ``basis`` one equality constraint is
    added: the corrections summed around the cycle equal the sum of
    ``constraint_right_side`` over the same directed edges.

    Args:
        sorted_half_edges: one entry per undirected edge; iterated once.
        basis: cycles as node lists. The closing edge from the last node back
            to the first is implied, and every cycle edge must be one of the
            half-edges (in either direction).
        constraint_right_side: provides the per-edge right-hand-side value.
        capacity: upper bound of every variable (``None`` means unbounded).
        confidence_weight_per_edge: optional objective weights per edge, looked
            up with ``get_elem_from_sorted_dict``.

    Returns:
        Mapping from directed edge to its solved ``pulp.LpVariable``.

    Raises:
        AssertionError: if the solver does not report an optimal solution.
    """
    # Define an LP problem
    prob = pulp.LpProblem("LP_problem", pulp.LpMinimize)

    # Create LP variables, delta+/- >= 0
    variables = {
        e: pulp.LpVariable(edge_to_str(e), 0, capacity, pulp.LpContinuous)
        for she in sorted_half_edges
        for e in (she, reverse_edge(she))
    }

    if confidence_weight_per_edge is not None:
        # Objective to minimize = sum cij |deltaij| = sum cij deltaij+/-
        affine_expr = [
            var * get_elem_from_sorted_dict(confidence_weight_per_edge, e)
            for e, var in variables.items()
        ]
    else:
        # Objective to minimize = sum |deltaij| = sum deltaij+/-
        affine_expr = list(variables.values())

    # Add objective
    prob += pulp.lpSum(affine_expr)

    for cycle in basis:
        # One constraint per cycle; close it by coming back to its first node
        closed = cycle + [cycle[0]]
        cycle_edges = list(zip(closed[:-1], closed[1:]))
        delta_term = [
            var
            for a, b in cycle_edges
            for var in (variables[(a, b)], -variables[(b, a)])
        ]
        right_side_terms = [constraint_right_side.get_at_edge(e) for e in cycle_edges]

        prob.addConstraint(pulp.lpSum(delta_term) == np.sum(right_side_terms))

    # Define the solver (continuous LP, silent)
    solver = pulp.PULP_CBC_CMD(mip=False, msg=False)
    # Solve the problem; perf_counter is monotonic, unlike time.time
    start = time.perf_counter()
    status = prob.solve(solver=solver)
    assert (
        status == pulp.LpStatusOptimal
    ), f"Solution not found, exit status: {pulp.LpStatus[status]}"
    print(
        f"status:  {pulp.LpStatus[status]}, time for solving LP : ",
        time.perf_counter() - start,
    )

    return variables


# Provides the deltas to correct the wrapped gradient into a true gradient
@dataclass(frozen=True)
class LPDeltaProvider(DeltaProvider):
    """Edge corrections read back from the variables solved by the LP.

    Attributes:
        variables: solved LP variables keyed by directed edge. Both ``(x, y)``
            and ``(y, x)`` must be present for every queried edge.
        eps_deviation_from_int: if not ``None``, the deltas should be integer.
            In that case a sanity check is made,
            ``abs(delta - round(delta)) < eps_deviation_from_int``, and the
            delta is rounded.
    """

    variables: dict[_E, pulp.LpVariable]
    eps_deviation_from_int: Optional[float] = None

    def potential_round(self, delta_xy: float) -> float:
        """Round ``delta_xy`` to the nearest integer when integrality is expected.

        Raises:
            AssertionError: if ``delta_xy`` is not within
                ``eps_deviation_from_int`` of its rounded value.
        """
        if self.eps_deviation_from_int is not None:
            # It should be integer
            int_delta = round(delta_xy)
            assert abs(delta_xy - int_delta) < self.eps_deviation_from_int, (
                f"Integer condition broken: abs({delta_xy} - {int_delta}) "
                f">= {self.eps_deviation_from_int}"
            )
            delta_xy = int_delta

        return delta_xy

    def get_delta(self, e: _E) -> float:
        r"""
        Get delta of edge e = (x, y).

        The LP enforces, on cycles:
            \sum \delta_xy = \sum constrain_right_side_xy,
        which is \sum (\delta_xy_plus - \delta_xy_minus) = \sum constrain_right_side_xy,
        i.e.     \sum (\delta_xy_plus - \delta_yx_plus) = \sum constrain_right_side_xy,
        because the variable of the reverse edge plays the role of the minus part.

        Thus \delta_xy = \delta_xy_plus - \delta_yx_plus (rounded by
        ``potential_round`` when integrality is expected).
        """
        x, y = e
        # delta_xy_plus is in variables
        delta_xy_plus = pulp.value(self.variables[(x, y)])

        # delta_yx_plus is in variables, it is also delta_xy_minus
        delta_yx_plus = pulp.value(self.variables[(y, x)])

        return self.potential_round(delta_xy_plus - delta_yx_plus)


def solve_unwrap_on_graph(
    graph_structure: GraphStructure,
    phase: NDArray[np.float64],
    h: float,
    capacity: Optional[float],
) -> NDArray[np.float64]:
    """Unwrap one phase array on a precomputed graph.

    Pipeline: wrapped gradient per edge, then LP for the cycle corrections,
    then integer-checked corrections, then integration along the
    breadth-first order.
    """
    edges = graph_structure.sorted_half_edges
    wamb_grad_per_edge = compute_wamb_grad_per_edge(edges, phase, h)

    solved = get_optimum_deltas(
        edges,
        graph_structure.basis,
        WambConstraintRightSide(wamb_grad_per_edge),
        capacity,
    )
    # Corrections must be integers here; LP round-off up to 1e-3 is tolerated
    deltas = LPDeltaProvider(solved, eps_deviation_from_int=1e-3)

    grad_provider = WambUnwGradProvider(wamb_grad_per_edge, deltas, h)
    return integrate(graph_structure.nodes_predecessors_order, grad_provider)


@dataclass(frozen=True)
class LPUnwrapResultProvider(ResultProvider):
    """Per-array unwrapping job handed to ``get_result_on_stack``.

    Holds everything except the phase itself, so the same graph can be applied
    to each array of a stack. ``result_shape`` is not read here; presumably it
    is consumed by the stack driver (TODO confirm).
    """

    result_shape: tuple[int, ...]
    graph_structure: GraphStructure
    h: float
    capacity: Optional[float] = None

    def get_result(self, in_array: NDArray[np.float64]) -> NDArray[np.float64]:
        """Unwrap one phase array with the stored graph, ``h`` and capacity."""
        graph, half_period, cap = self.graph_structure, self.h, self.capacity
        return solve_unwrap_on_graph(graph, in_array, half_period, cap)


def solve_unwrap_stack_on_graph(
    graph_structure: GraphStructure,
    phases: list[NDArray[np.float64]],
    h: float,
    capacity: Optional[float] = None,
    ncpu: int = 1,
) -> list[NDArray[np.float64]]:
    """Unwrap every phase array of ``phases`` on the same precomputed graph.

    Args:
        graph_structure: phase-independent graph description.
        phases: wrapped phase arrays, all of the same shape.
        h: wrapping parameter forwarded to the per-array solver.
        capacity: optional upper bound on each LP correction variable.
        ncpu: number of workers passed to ``get_result_on_stack``.

    Returns:
        The unwrapped arrays as produced by ``get_result_on_stack``, or an
        empty list for an empty stack.
    """
    if len(phases) == 0:
        # Nothing to unwrap; also avoids indexing phases[0] below
        return []

    result_shape = phases[0].shape
    result_provider = LPUnwrapResultProvider(result_shape, graph_structure, h, capacity)
    unwrapped = get_result_on_stack(result_provider, phases, ncpu)
    return unwrapped


# This corresponds to algorithm 1.
def unwrap_delaunay_linprog_ts(
    pts: NDArray[np.int64],
    phases: list[NDArray[np.float64]],
    k: int,
    h: float = np.pi,
    b: Optional[float] = None,
    useSmallBasis: bool = True,
    integration_ref_id: np.int32 = np.int32(0),
    ncpu: int = 1,
) -> list[NDArray[np.float64]]:
    """Unwrap a stack of phase arrays sampled on scattered points (algorithm 1).

    Args:
        pts: point coordinates, one row per sample; triangulated with Delaunay.
        phases: wrapped phase arrays, each with one value per point. Their range
            is validated by ``interval_phase_assertion`` with ``h``.
        k: redundancy order of the edge graph (see ``get_AdjRed``).
        h: wrapping parameter (default pi) used for validation and unwrapping.
        b: optional bound. When given, every LP correction variable is capped
            at ``(b + h) / (2 * h)``.
        useSmallBasis: use the small cycle basis (algorithm 2).
        integration_ref_id: start node of the integration.
        ncpu: number of workers passed to ``get_result_on_stack``.

    Returns:
        One unwrapped array per input phase array.

    Raises:
        AssertionError: if a phase array fails validation or its length
            differs from the number of points.
    """
    for i, phase in enumerate(phases):
        interval_phase_assertion(phase, h, f"Problem for array of index {i}")
        assert len(phase) == len(
            pts
        ), f"Problem for array of index {i}: phase and pts must have the same length"

    graph_structure = get_graph_structure(
        pts, k, useSmallBasis=useSmallBasis, integration_ref_id=integration_ref_id
    )

    # capacities on edge corrections, i.e. max_deltas
    capacity = (b + h) / (2 * h) if b is not None else None

    unwrapped = solve_unwrap_stack_on_graph(graph_structure, phases, h, capacity, ncpu)

    return unwrapped


def get_graph_structure_from_edges(
    sorted_half_edges: Iterable[_E], integration_ref_id: np.int32 = np.int32(0)
) -> GraphStructure:
    """Build a GraphStructure directly from an edge list (no triangulation).

    The symmetric adjacency of the edges gives both the breadth-first
    integration order (from ``integration_ref_id``) and the cycle basis
    (``GetFundamentalCycleBasis`` over all nodes).
    """
    sym_adj = get_adj_from_sorted_half_edges(sorted_half_edges, symmetric=True)
    bfs = breadth_first_order(sym_adj, integration_ref_id, directed=False)

    cycles, _ = GetFundamentalCycleBasis(
        np.arange(sym_adj.shape[0], dtype=np.int32),
        CSRAdjNeighborProvider(sym_adj.indptr, sym_adj.indices),
    )
    return GraphStructure(sorted_half_edges, cycles, bfs)


def solve_integrate_on_graph(
    graph_structure: GraphStructure,
    observed_grad_per_edge: dict[_E, float],
    confidence_weight_per_edge: Optional[dict[_E, float]],
) -> NDArray[np.float64]:
    """Integrate an observed per-edge gradient field on the graph.

    The LP finds (weighted) minimal corrections such that, on each basis cycle,
    the summed corrections equal the summed observed gradients. The corrections
    are not rounded (no integrality check). The corrected gradient is then
    integrated along the breadth-first order.
    """
    rhs = FloatConstraintRightSide(observed_grad_per_edge)
    solved = get_optimum_deltas(
        graph_structure.sorted_half_edges,
        graph_structure.basis,
        rhs,
        confidence_weight_per_edge=confidence_weight_per_edge,
    )
    corrections = LPDeltaProvider(solved, eps_deviation_from_int=None)

    corrected_grad = FloatUnwGradProvider(observed_grad_per_edge, corrections)
    return integrate(graph_structure.nodes_predecessors_order, corrected_grad)


@dataclass(frozen=True)
class IntegrationResultProvider(ResultProvider):
    """Per-array integration job handed to ``get_result_on_stack``.

    Each input array is turned into a per-edge gradient dict by
    ``dict_per_edge_provider`` and integrated on ``graph_structure``. The
    optional ``weights`` become the LP objective weights.
    """

    result_shape: tuple[int, ...]
    graph_structure: GraphStructure
    dict_per_edge_provider: DictPerEdgeProvider
    weights: Optional[NDArray[np.float64]] = None

    def get_result(self, in_array: NDArray[np.float64]) -> NDArray[np.float64]:
        """Integrate one gradient array with the stored graph and weights."""
        per_edge = self.dict_per_edge_provider
        grad_dict = per_edge.get_grad_dict(in_array)
        weight_dict = per_edge.get_weight_dict(self.weights)
        return solve_integrate_on_graph(self.graph_structure, grad_dict, weight_dict)


def solve_integrate_stack_on_graph(
    graph_structure: GraphStructure,
    dict_per_edge_provider: DictPerEdgeProvider,
    grads_list: list[NDArray[np.float64]],
    weights: Optional[NDArray[np.float64]] = None,
    ncpu: int = 1,
) -> list[NDArray[np.float64]]:
    """Integrate every gradient array of ``grads_list`` on the same graph.

    The declared result shape is ``(n,)``, where ``n`` is the length of the
    breadth-first node order stored in ``graph_structure``.
    """
    node_order = graph_structure.nodes_predecessors_order[0]
    job = IntegrationResultProvider(
        result_shape=(len(node_order),),
        graph_structure=graph_structure,
        dict_per_edge_provider=dict_per_edge_provider,
        weights=weights,
    )
    return get_result_on_stack(job, grads_list, ncpu)


def redundant_integration_stack_linprog(
    edges: NDArray[np.int32],
    grads_list: list[NDArray[np.float64]],
    weights: Optional[NDArray[np.float64]] = None,
    integration_ref_id: np.int32 = np.int32(0),
    ncpu: int = 1,
) -> list[NDArray[np.float64]]:
    """Integrate a stack of gradient arrays defined on an explicit edge list.

    Args:
        edges: edge array, converted by ``DictPerEdgeProvider.from_edges_array``.
        grads_list: gradient arrays, presumably one value per edge (TODO
            confirm against ``DictPerEdgeProvider.get_grad_dict``).
        weights: optional confidence weights, turned into LP objective weights
            by ``DictPerEdgeProvider.get_weight_dict``.
        integration_ref_id: start node of the integration.
        ncpu: number of workers passed to ``get_result_on_stack``.

    Returns:
        One integrated array per gradient array.
    """
    per_edge = DictPerEdgeProvider.from_edges_array(edges)
    structure = get_graph_structure_from_edges(
        per_edge.sorted_half_edges, integration_ref_id
    )
    return solve_integrate_stack_on_graph(
        structure, per_edge, grads_list, weights, ncpu
    )
