"""
BSD 2-Clause License

Copyright (c) 2024, Centre Borelli
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
from dataclasses import dataclass
from sys import exit
from time import time
from typing import Iterable, Optional

import numpy as np
from delaunay_unwrap.utils import (
    _E,
    ConstraintRightSide,
    DeltaProvider,
    FloatConstraintRightSide,
    ResultProvider,
    UnwGradProvider,
    WambConstraintRightSide,
    WambUnwGradProvider,
    compute_wamb_grad_per_edge,
    delaunay_to_AdjMat,
    get_elem_from_sorted_dict,
    get_elem_from_sorted_grad_dict,
    get_result_on_stack,
    get_sorted_half_edges,
    integrate,
    interval_phase_assertion,
)
from numpy.typing import NDArray
from ortools.graph.python import min_cost_flow
from scipy.sparse.csgraph import breadth_first_order
from scipy.spatial import Delaunay

# This is an implementation of algorithm 5.


@dataclass(frozen=True)
class DualNetwork:
    """Arcs of the dual graph handed to the min cost flow solver.

    Arc ``i`` goes from dual node ``dual_start_nodes[i]`` to dual node
    ``dual_end_nodes[i]``; ``dual_flowid_per_primal_edge`` maps each primal
    half-edge to the id ``i`` of the dual arc associated with it.
    """

    dual_start_nodes: NDArray[np.int64]
    dual_end_nodes: NDArray[np.int64]
    dual_flowid_per_primal_edge: dict[_E, int]

    @property
    def primal_edge_per_dual_flow_id(self) -> dict[int, _E]:
        """Inverse of ``dual_flowid_per_primal_edge``, rebuilt on every access."""
        by_edge = self.dual_flowid_per_primal_edge
        return dict(zip(by_edge.values(), by_edge.keys()))


# This corresponds to line 3 of algorithm 5.
def get_dual_network(
    simplices: NDArray[np.int32],
    neighbors: NDArray[np.int32],
) -> DualNetwork:
    """Build the dual network of a triangulation for the min cost flow problem.

    Every triangle is a dual node with the same index, plus one extra "earth"
    node with index ``len(simplices)`` standing for the outside of the
    triangulation. For each triangle and each of its edges (a, b) there is a
    dual arc from the triangle to the neighbor opposite its third vertex, or
    to the earth node when that neighbor is -1. Boundary edges additionally
    get the reverse arc (earth -> triangle), associated with the flipped
    half-edge (b, a), so flow can cross the boundary in both directions.

    Args:
        simplices: vertex indices of each triangle, one row per triangle.
        neighbors: ``neighbors[i, c]`` is the triangle opposite vertex ``c`` of
            triangle ``i``, or -1 when there is none (SciPy ``Delaunay``
            convention).

    Returns:
        DualNetwork whose arc ``i`` runs from ``dual_start_nodes[i]`` to
        ``dual_end_nodes[i]`` and is associated with the primal half-edge ``e``
        such that ``dual_flowid_per_primal_edge[e] == i``.
    """
    dual_start = []
    dual_end = []
    primal_edges = []
    n_simplices = len(simplices)
    earth_node_id = n_simplices
    sids = np.arange(n_simplices)

    def add_edges(a: int, b: int, c: int) -> None:
        # Edge (a, b) of every triangle is crossed by the dual arc going to
        # the neighbor opposite vertex c.
        edges = simplices[:, [a, b]]
        start, end = sids, neighbors[:, c].copy()

        # Missing neighbors (-1) are redirected to the earth node.
        mask = end == -1
        end[mask] = earth_node_id
        dual_start.append(start)
        dual_end.append(end)
        primal_edges.append(edges)

        # Boundary edges also get the reverse arc (earth -> triangle),
        # associated with the flipped half-edge.
        dual_start.append(end[mask])
        dual_end.append(start[mask])
        primal_edges.append(edges[mask][:, [1, 0]])

    add_edges(0, 1, 2)
    add_edges(1, 2, 0)
    add_edges(2, 0, 1)

    dual_start_nodes = np.hstack(dual_start)
    dual_end_nodes = np.hstack(dual_end)
    all_primal_edges = np.vstack(primal_edges)

    # Converting to nested Python lists first avoids creating and hashing one
    # NumPy scalar per coordinate, which is the slow part of building this
    # dict. Python int keys hash and compare equal to NumPy integer keys, so
    # later lookups with either kind of integer still succeed.
    dual_flowid_per_primal_edge = {
        (x, y): flow_id for flow_id, (x, y) in enumerate(all_primal_edges.tolist())
    }

    return DualNetwork(dual_start_nodes, dual_end_nodes, dual_flowid_per_primal_edge)


@dataclass(frozen=True)
class GraphStructure:
    """Phase-independent graph data, built once and reused for every phase array.

    Produced by ``get_graph_structure``; ``solve_unwrap_stack_on_graph`` shares
    a single instance across all arrays of a stack.
    """

    # Triangles of the Delaunay triangulation, one row of vertex indices each.
    simplices: NDArray[np.int32]
    # Half-edges of the triangulation as returned by get_sorted_half_edges;
    # the ordering convention is defined by that helper.
    sorted_half_edges: Iterable[_E]
    # Dual network (triangles + earth node) solved by the min cost flow.
    dual_network: DualNetwork
    # (nodes, predecessors) of a breadth-first traversal of the primal
    # adjacency graph, used as the integration path.
    nodes_predecessors_order: tuple[NDArray[np.int32], NDArray[np.int32]]


# Get essential graph properties
def get_graph_structure(
    pts: NDArray[np.int64], integration_ref_id: np.int32 = np.int32(0)
) -> GraphStructure:
    """Build the phase-independent graph data for a set of points.

    1. Delaunay triangulation of ``pts`` with Qhull option "QJ" (joggled
       input). SciPy orients 2-D simplices counterclockwise, so no
       reorientation is done.
    2. Primal adjacency matrix and its sorted half-edges.
    3. Dual network of the triangulation (triangles + earth node), each dual
       arc associated with a counterclockwise primal half-edge.
    4. Breadth-first order and predecessors from ``integration_ref_id`` on the
       undirected adjacency graph; only this path is needed to integrate.
    """
    triangulation = Delaunay(pts, qhull_options="QJ")
    # Cheap: the triangulation already stores its adjacency in CSR form.
    adjacency = delaunay_to_AdjMat(triangulation)

    return GraphStructure(
        simplices=triangulation.simplices,
        sorted_half_edges=get_sorted_half_edges(adjacency),
        dual_network=get_dual_network(
            triangulation.simplices, triangulation.neighbors
        ),
        nodes_predecessors_order=breadth_first_order(
            adjacency, integration_ref_id, directed=False
        ),
    )


# This corresponds to lines 9-10 of algorithm 5.
def compute_residues(
    simplices: NDArray[np.int32], constraint_right_side: ConstraintRightSide
) -> NDArray[np.int64]:
    """Compute the supply of every dual node: one per triangle, then the earth node.

    The residue of triangle (s0, s1, s2) is the sum of the constraint right
    side over its half-edges (s0, s1), (s1, s2), (s2, s0), i.e. walking around
    the triangle in vertex order. Delaunay triangles are counterclockwise, so
    no orientation correction is applied (residues as supplies: Costantini 98).

    The last entry (earth node) balances all the others so the supplies sum
    to zero, since no flow can leave the graph.

    Returns:
        int64 array of length ``len(simplices) + 1``.
    """
    supplies = np.zeros(len(simplices) + 1, dtype=np.int64)

    for tri_idx, (v0, v1, v2) in enumerate(simplices):
        supplies[tri_idx] = (
            constraint_right_side.get_at_edge((v0, v1))
            + constraint_right_side.get_at_edge((v1, v2))
            + constraint_right_side.get_at_edge((v2, v0))
        )

    supplies[-1] = -supplies[:-1].sum()
    return supplies


# This corresponds to line 15 of algorithm 5.
def get_minimum_cost_flows(
    residues: NDArray[np.int64],
    dual_network: DualNetwork,
    capacity: int,
    cost_per_primal_edge: Optional[dict[_E, int]] = None,
) -> NDArray[np.int64]:
    """Solve the min cost flow problem on the dual network with OR-Tools.

    Args:
        residues: supply of each dual node (triangles first, earth node last).
        dual_network: arcs of the dual network.
        capacity: capacity given to every dual arc.
        cost_per_primal_edge: optional integer cost per primal half-edge; each
            dual arc's cost is looked up (via ``get_elem_from_sorted_dict``)
            from the primal edge associated with it. Unit costs when None.

    Returns:
        Flow on every dual arc, indexed by dual flow id.

    Raises:
        RuntimeError: if the solver does not report an optimal solution.
    """
    num_edges = len(dual_network.dual_start_nodes)

    capacities = np.full((num_edges,), capacity, np.int64)

    if cost_per_primal_edge is None:
        costs = np.ones_like(dual_network.dual_start_nodes, dtype=np.int64)
    else:
        # The property rebuilds the inverse mapping on each access: fetch once.
        primal_edge_per_dual_flow_id = dual_network.primal_edge_per_dual_flow_id
        costs = np.array(
            [
                get_elem_from_sorted_dict(
                    cost_per_primal_edge, primal_edge_per_dual_flow_id[fid]
                )
                for fid in range(num_edges)
            ],
            dtype=np.int64,
        )

    smcf = min_cost_flow.SimpleMinCostFlow()
    # Add arcs, capacities and costs in bulk using numpy.
    all_arcs = smcf.add_arcs_with_capacity_and_unit_cost(
        dual_network.dual_start_nodes, dual_network.dual_end_nodes, capacities, costs
    )

    smcf.set_nodes_supplies(np.arange(0, len(residues)), residues)

    start = time()
    status = smcf.solve()
    print("Time for solving MCF : ", time() - start)

    if status != smcf.OPTIMAL:
        # Raise rather than sys.exit(): a library function must not terminate
        # the interpreter, and SystemExit derives from BaseException, so it
        # escapes `except Exception` handlers (e.g. worker-pool error reporting).
        raise RuntimeError(
            f"There was an issue with the min cost flow input. Status: {status}"
        )

    return smcf.flows(all_arcs)


# Provides the optimal deltas to fix the gradient
@dataclass(frozen=True)
class MCFDeltaProvider(DeltaProvider):
    """Delta provider backed by the solution of the dual min cost flow problem.

    ``dual_flowid_per_primal_edge`` maps each primal half-edge to the id of its
    dual arc, and ``flows`` holds the solver's flow for each arc id.
    """

    dual_flowid_per_primal_edge: dict[_E, int]
    flows: NDArray[np.int64]

    def get_delta(self, e: _E) -> int:
        r"""Return the integer correction \delta_xy of the half-edge e = (x, y).

        The edge starts at x and ends at y; the wrapped difference y - x is
        corrected by the net dual flow across it, flow out minus flow in.

        Equation for edge (x, y):
        f_y - f_x + 2h \delta_xy = f'_y - f'_x + 2h wrapping_amb_xy

        On cycles: \sum \delta_xy = \sum wrapping_amb_xy,
        which is \sum (flow_out - flow_in) = supply.

        Thus \delta_xy = flow_xy_out - flow_xy_in
        and f_y = f_x + f'_y - f'_x + 2h (wrapping_amb_xy - \delta_xy)

        Both (x, y) and (y, x) must be keys of ``dual_flowid_per_primal_edge``;
        a missing key raises KeyError.
        """
        x, y = e
        # Flow leaving across (x, y).
        flow_xy_out = self.flows[self.dual_flowid_per_primal_edge[(x, y)]]
        # Flow entering across the opposite half-edge (y, x).
        flow_xy_in = self.flows[self.dual_flowid_per_primal_edge[(y, x)]]

        return flow_xy_out - flow_xy_in


def solve_unwrap_on_graph(
    graph_structure: GraphStructure, phase: NDArray[np.float64], h: float, c: int
) -> NDArray[np.float64]:
    """Unwrap one phase array on a precomputed graph.

    Pipeline: wrapped gradients and wrapping ambiguities per half-edge ->
    triangle residues -> min cost flow on the dual network (uniform arc
    capacity ``c``, unit costs) -> integration of the corrected gradients
    along the breadth-first tree.
    """
    wamb_grads = compute_wamb_grad_per_edge(
        graph_structure.sorted_half_edges, phase, h
    )
    supplies = compute_residues(
        graph_structure.simplices, WambConstraintRightSide(wamb_grads)
    )

    dual = graph_structure.dual_network
    dual_flows = get_minimum_cost_flows(supplies, dual, c)
    deltas = MCFDeltaProvider(dual.dual_flowid_per_primal_edge, dual_flows)

    unw_grads = WambUnwGradProvider(wamb_grads, deltas, h)
    return integrate(graph_structure.nodes_predecessors_order, unw_grads)


@dataclass(frozen=True)
class MCFUnwrapResultProvider(ResultProvider):
    """Unwraps one phase array at a time with the MCF method.

    Bundles everything shared across a stack of arrays, so ``get_result`` only
    needs the array itself. ``result_shape`` is not read here; presumably it
    is consumed by the ResultProvider machinery (get_result_on_stack).
    """

    result_shape: tuple[int, ...]
    graph_structure: GraphStructure
    h: float
    capacity: int

    def get_result(self, in_array: NDArray[np.float64]) -> NDArray[np.float64]:
        """Unwrap ``in_array`` on the stored graph with the stored h and capacity."""
        return solve_unwrap_on_graph(
            graph_structure=self.graph_structure,
            phase=in_array,
            h=self.h,
            c=self.capacity,
        )


def solve_unwrap_stack_on_graph(
    graph_structure: GraphStructure,
    phases: list[NDArray[np.float64]],
    h: float,
    capacity: int,
    ncpu: int = 1,
) -> list[NDArray[np.float64]]:
    """Unwrap every array of ``phases`` on one shared precomputed graph.

    The shape of the first array is recorded as the result shape; the work is
    dispatched through ``get_result_on_stack`` with ``ncpu``.
    """
    provider = MCFUnwrapResultProvider(
        result_shape=phases[0].shape,
        graph_structure=graph_structure,
        h=h,
        capacity=capacity,
    )
    return get_result_on_stack(provider, phases, ncpu)


# This corresponds to algorithm 5.
def unwrap_delaunay_MCF_ts(
    pts: NDArray[np.int64],
    phases: list[NDArray[np.float64]],
    h: float = np.pi,
    b: Optional[float] = None,
    integration_ref_id: np.int32 = np.int32(0),
    ncpu: int = 1,
) -> list[NDArray[np.float64]]:
    """Unwrap a stack of wrapped phase arrays sampled on the same sparse points.

    Args:
        pts: point coordinates used to build the Delaunay triangulation.
        phases: wrapped phase arrays, one value per point each
            (``len(phase) == len(pts)``), validated by interval_phase_assertion.
        h: half of the wrapping period; gradients are corrected by integer
            multiples of ``2 * h``.
        b: optional bound from which the uniform dual arc capacity
            ``floor((b + h) / (2 * h))`` is derived; a large capacity (9999)
            is used when None.
        integration_ref_id: index of the point where the breadth-first
            integration starts.
        ncpu: forwarded to get_result_on_stack.

    Returns:
        One unwrapped array per input phase array.

    Raises:
        AssertionError: if a phase array fails validation or its length does
            not match ``pts`` (when assertions are enabled).
    """
    for i, phase in enumerate(phases):
        interval_phase_assertion(phase, h, f"Problem for array of index {i}")
        # Single-line message: no embedded newline or indentation padding.
        assert len(phase) == len(pts), (
            f"Problem for array of index {i}: "
            "phase and pts must have the same length"
        )

    graph_structure = get_graph_structure(pts, integration_ref_id)

    # Uniform capacities (and unit costs) on the dual edges.
    if b is not None:
        # Capacity derived from the bound b and the wrapping period 2h.
        c = int(np.floor((b + h) / (2 * h)))
    else:
        # (Apparently None isn't supported for capacities, set something very high instead)
        c = 9999

    unwrapped = solve_unwrap_stack_on_graph(graph_structure, phases, h, c, ncpu)

    return unwrapped


def scale_and_round_grad_per_edge(
    grad_per_edge: dict[_E, float], epsilon: float = 1e-3
) -> dict[_E, int]:
    """Quantize each gradient to the integer ``round(g / epsilon)``.

    ``round`` rounds halves to even. The result is wrapped in ``int`` so values
    are plain Python ints even when the gradients are NumPy floating scalars
    (whose ``round`` may return a float depending on the NumPy version), as
    promised by the ``dict[_E, int]`` return type.
    """
    return {e: int(round(g / epsilon)) for e, g in grad_per_edge.items()}


@dataclass(frozen=True)
class EpsIntUnwGradProvider(UnwGradProvider):
    r"""
    Corrected-gradient provider for integrating a quantized gradient field.

    With a quantization step epsilon (e.g. 1e-3):

    f''_xy = round(f'_xy / epsilon)
    f_y - f_x + \delta_xy * epsilon = f''_xy * epsilon

    constraint on cycles
    \sum \delta_xy = \sum f''_xy

    integration on spanning tree
    f_y = f_x + (f''_xy - \delta_xy) * epsilon
    """

    # Quantized gradient f''_xy per edge.
    epsint_grad_per_edge: dict[_E, int]
    # Quantization step used to produce epsint_grad_per_edge.
    epsilon: float
    # Source of the integer corrections delta_xy (min cost flow solution).
    delta_provider: MCFDeltaProvider

    def get_at_edge(self, edge: _E) -> float:
        """Return the corrected gradient ``(f''_xy - delta_xy) * epsilon`` of ``edge``."""
        return (
            get_elem_from_sorted_grad_dict(self.epsint_grad_per_edge, edge)
            - self.delta_provider.get_delta(edge)
        ) * self.epsilon


def get_int_cost_per_edge(
    confidence_weight_per_edge: Optional[dict[_E, float]],
) -> Optional[dict[_E, int]]:
    """Convert confidence weights in [0, 1] to integer min cost flow costs.

    Weights are normalized by their maximum, scaled to 1e12 and rounded up, so
    the most confident edge costs ``10**12``, any positive weight costs at
    least 1 and zero weights cost 0.

    Args:
        confidence_weight_per_edge: weight per primal edge, or None.

    Returns:
        Integer cost per edge; None when no weights are given (the flow solver
        then uses unit costs). An empty mapping yields an empty mapping.

    Raises:
        AssertionError: if a weight lies outside [0, 1] (when assertions are
            enabled).
    """
    if confidence_weight_per_edge is None:
        return None

    assert all(0 <= v <= 1 for v in confidence_weight_per_edge.values()), """
        confidence weight should be in [0, 1] interval
        """
    # default=0 routes an empty mapping into the zero-maximum guard below.
    w_max = max(confidence_weight_per_edge.values(), default=0)
    if w_max == 0:
        # All weights are zero (or there are none): avoid division by zero.
        w_max = 1
    # np.ceil returns a float; cast so the costs really are integers.
    return {
        e: int(np.ceil(w / w_max * 1e12))
        for e, w in confidence_weight_per_edge.items()
    }


def solve_integrate_on_graph(
    graph_structure: GraphStructure,
    observed_grad_per_edge: dict[_E, float],
    confidence_weight_per_edge: Optional[dict[_E, float]],
    capacity: int = int(1e12),
    epsilon: float = 1e-3,
) -> NDArray[np.float64]:
    """Integrate an observed gradient field on the graph with min cost flow.

    The gradients are quantized with step ``epsilon``. The min cost flow on the
    dual network (uniform ``capacity``, costs derived from the confidence
    weights, unit costs when None) yields integer corrections per edge. The
    corrected field is then integrated along the breadth-first tree.
    """
    quantized = scale_and_round_grad_per_edge(observed_grad_per_edge, epsilon)

    supplies = compute_residues(
        graph_structure.simplices, FloatConstraintRightSide(quantized)
    )

    dual = graph_structure.dual_network
    arc_costs = get_int_cost_per_edge(confidence_weight_per_edge)
    dual_flows = get_minimum_cost_flows(supplies, dual, capacity, arc_costs)

    correction = MCFDeltaProvider(dual.dual_flowid_per_primal_edge, dual_flows)
    grad_provider = EpsIntUnwGradProvider(quantized, epsilon, correction)
    return integrate(graph_structure.nodes_predecessors_order, grad_provider)
