
# Copyright (c) 2024, Olakunle Abawonse, Gunay Dogan
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


import time
import argparse
import numpy as np
from numpy.linalg import norm
from scipy.ndimage import gaussian_gradient_magnitude
from matplotlib.pyplot import imread, imsave, gray


#---------------------------------------------------------------------
# The following function implements 'Updating average c_1, c_2' part
# of Subsection 3.2, specifically, the equations (22) and (23).
#---------------------------------------------------------------------

def _update_data_term(f, u):
    r"""Update the function r = (f - c1)^2 - (f - c2)^2 using region info u.

    This function updates the image intensity averages :math:`c_1` and
    :math:`c_2` for regions :math:`\Omega_1` and :math:`\Omega_2`.
    It also updates the data term :math:`r` in the segmentation energy

    .. math:: r = (f - c_1)^2 - (f - c_2)^2

    The regions :math:'\Omega_1', :math:'\Omega_2` are defined through
    the function :math:`u` by :math:`\Omega_1 = \{(i,j): u[i,j] >= 0.5 \}`
    and :math:`\Omega_2 = \{(i,j): u[i,j] < 0.5 \}`.

    Parameters
    -----------
    f : array_like
        The array of image pixels.
    u : array_like
        Array storing the values of the region indicator function :math:`u`.

    Returns
    -------
    r : array_like
        Updated values of the function :math:`r = (f - c_1)^2 - (f - c_2)^2`.
    regions : pair of Numpy arrays
        Pair :math:`(\Omega_1,\Omega_2)` of boolean arrays indicating regions.
    averages : pair of float
        The averages :math:`(c_1,c_2)` of image pixel values in the two regions.
    """
    region_1 = (u >= 0.5)
    region_2 = (u <  0.5)

    area_1 = region_1.sum()
    area_2 = region_2.sum()

    c1 = f[ region_1 ].sum() / area_1  if area_1 > 0  else 0.0
    c2 = f[ region_2 ].sum() / area_2  if area_2 > 0  else 0.0

    r = (f - c1)**2 - (f - c2)**2

    averages = (c1, c2)
    regions  = (region_1, region_2)

    return r, regions, averages


#---------------------------------------------------------------------
# The following function implements 'The u subproblem' part of
# Subsection 3.2, specifically, one iteration of the Jacobi update
# used to approximate the solution of the PDE in equation (20).
#---------------------------------------------------------------------

def _jacobi_update_for_u(u, lambdda, gamma, r, d, b):
    r"""Compute update for `u` with one Jacobi sweep.

    The function computes the update for `u`, in the form of one sweep
    of Jacobi iteration applied to the equation

    .. math:: \Delta u = (\lambda/\gamma) r + div(d - b)

    Every entry of the result is an average over the neighbours of the
    pixel that lie inside the grid: four neighbours (weight 1/4) for
    interior pixels, three (weight 1/3) on the edges and two (weight 1/2)
    at the corners. Terms that would refer to points outside the grid are
    left out of the corresponding formulas. The input `u` is only read;
    the result is written to a new array.

    Parameters
    ----------
    u : array_like
        2d array with the current values of `u`, indexed as ``u[i, j]``.
        The corner formulas read ``u[1, 0]`` and ``u[0, 1]``, so at least
        two rows and two columns are needed.
    lambdda : float
        Factor :math:`\lambda` multiplying `r`; the caller in this file
        passes the data weight of the segmentation energy here.
    gamma : float
        Parameter :math:`\gamma` dividing :math:`\lambda` in the equation.
    r : array_like
        2d array of the same shape as `u` holding the data term.
    d : array_like
        Array indexed as ``d[k, i, j]`` with ``k`` in {0, 1}; component 0
        is differenced along the first axis, component 1 along the second.
    b : array_like
        Array indexed like `d`; it enters with the opposite sign of `d`.

    Returns
    -------
    s : array_like
        Array storing the values of a single Jacobi update.
    """
    # From here on m and n are the LAST valid row and column indices.
    m,n = u.shape;  m = m-1;  n = n-1
    s = np.empty_like(u)
    # Source term -(lambda/gamma)*r, moved to the right-hand side.
    R = (-lambdda/gamma) * r

    # Interior pixels: four neighbours of u minus div(d) plus div(b), where
    # the divergence is taken with backward differences.
    s[1:m, 1:n] = 0.25 * ( R[1:m, 1:n]
        + u[0:m-1, 1:n]  + u[2:m+1, 1:n]    + u[1:m, 0:n-1]  + u[1:m, 2:n+1]
        - d[0, 1:m, 1:n] + d[0, 0:m-1, 1:n] - d[1, 1:m, 1:n] + d[1, 1:m, 0:n-1]
        + b[0, 1:m, 1:n] - b[0, 0:m-1, 1:n] + b[1, 1:m, 1:n] - b[1, 1:m, 0:n-1] )

    # First row (without corners): there is no neighbour above.
    s[0, 1:n] = 1.0/3*( u[1, 1:n] + u[0, 2:n+1] + u[0, 0:n-1] + R[0, 1:n]
                       - d[0,0, 1:n] - d[1,0, 1:n] + d[1,0, 0:n-1]
                       + b[0,0, 1:n] + b[1,0, 1:n] - b[1,0, 0:n-1] )

    # Last row (without corners): there is no neighbour below.
    s[m, 1:n] = 1.0/3*( u[m-1, 1:n] + u[m, 2:n+1] + u[m, 0:n-1] + R[m, 1:n]
                       + d[0,m-1, 1:n] - d[1,m, 1:n] + d[1,m, 0:n-1]
                       - b[0,m-1, 1:n] + b[1,m, 1:n] - b[1,m, 0:n-1] )

    # First column (without corners): there is no neighbour to the left.
    s[1:m, 0] = 1.0/3*( u[1:m, 1] + u[2:m+1, 0] + u[0:m-1, 0] + R[1:m, 0]
                       - d[0, 1:m, 0] - d[1, 1:m, 0] + d[0, 0:m-1, 0]
                       + b[0, 1:m, 0] + b[1, 1:m, 0] - b[0, 0:m-1, 0] )

    # Last column (without corners): there is no neighbour to the right.
    s[1:m, n] = 1.0/3*( u[1:m, n-1] + u[2:m+1, n] + u[0:m-1, n] + R[1:m, n]
                       + d[0, 0:m-1, n] - d[0, 1:m, n] + d[1, 1:m, n-1]
                       - b[0, 0:m-1, n] + b[0, 1:m, n] - b[1, 1:m, n-1] )

    # The four corners: only two neighbours each.
    s[0,0] = 0.5*(  u[1,0] + u[0,1] + R[0,0]
                  - d[0,0,0] + b[0,0,0] - d[1,0,0] + b[1,0,0] )

    s[0,n] = 0.5*(  u[1,n] + u[0,n-1] + R[0,n]
                  - d[0,0,n] + b[0,0,n] + d[1,0,n-1] - b[1,0,n-1] )

    s[m,0] = 0.5*(  u[m-1,0] + u[m,1] + R[m,0]
                  + d[0,m-1,0] - b[0,m-1,0] - d[1,m,0] + b[1,m,0] )

    s[m,n] = 0.5*(  u[m-1,n] + u[m,n-1] + R[m,n]
                  + d[0,m-1,n] - b[0,m-1,n] + d[1,m,n-1] - b[1,m,n-1] )
    return s


#-----------------------------------------------------------------------
# The following function implements the update of dual variable 'b'
# given in 5th line inside the while-loop of Algorithm 4, also eqn (25).
#-----------------------------------------------------------------------

def _update_b(b, b_prev, u, d, tau):
    r"""Updates the values `b` using the previous information.

    The function performs a single iteration of

    ..math:: b = b_prev + \tau*(\nabla u - d)

    Parameters
    ----------
    b : array_like
    b_prev : array_like
    u : array_like
    d : array_lke
    tau : float

    Returns
    -------
    b : array_like
        Array storing the updated values of b.
    """
    m, n = u.shape;  m = m-1;  n = n-1
    b[:,:] = b_prev - tau * d
    b[0, 0:m, :] += tau * (u[1:m+1, :] - u[0:m, :])
    b[1, :, 0:n] += tau * (u[:, 1:n+1] - u[:, 0:n])
    return b


#---------------------------------------------------------------------
# The following function implements 'The d subproblem' part of
# Subsection 3.2, specifically, the equation (17).
#---------------------------------------------------------------------

def _subproblem_d(d, u, b, g, gamma):
    r"""Solves the d subproblem at each Bregman iteration for segmentation.

    This function computes the minimizer `d` of the modified segmentation energy

    ..math:: \sum_{ij} g_{ij} |d_{ij}| + 1/2 |d_{ij} - \nabla u_{ij} - b_{ij}|^2

    The minimizer is given by the following formula
    :math:`d = \frac{w}{|w|} \max\{ |w| - f / \gamma, 0 \}`,
    where :math:`w = \nabla u + b`.

    Parameters
    ----------
    d : array_like
    u : array_like
    b : array_like
    g : array_like
    gamma : float

    Returns
    -------
    d : array_like
    """
    epsilon = 1e-6  # to avoid zero values of T=Du in the denominator of q
    m,n = u.shape;  m = m-1;  n = n-1
    T = b.copy()

    T[0, 0:m, :] += (u[1:m+1, :] - u[0:m, :])
    T[1, :, 0:n] += (u[:, 1:n+1] - u[:, 0:n])

    norm_T = np.sqrt( T[0]**2 + T[1]**2 )

    q = (1 / (norm_T + epsilon**2)) * np.maximum( norm_T - g/gamma, 0.0 )
    d[:,:] = T * q
    return d


#---------------------------------------------------------------------
# The following function computes the segmentation energy
# given by equation (14) of Subsection 3.2.
#---------------------------------------------------------------------

def compute_energy(u, f, averages, data_weight, g):
    r"""Compute the value of the two-phase segmentation energy.

    This function computes the value of the two-phase (foreground-background)
    segmentation energy from given region indicator function :math:`u` and
    image :math:`f`. The value returned is

    .. math:: \frac{1}{N} \sum_{ij} \Big( g_{ij} (|D_1 u_{ij}| + |D_2 u_{ij}|)
              + \lambda \big( u_{ij} (f_{ij} - c_1)^2
              + (1 - u_{ij}) (f_{ij} - c_2)^2 \big) \Big)

    where :math:`D_1, D_2` are forward differences along the two axes (taken
    as zero where there is no forward neighbour), :math:`\lambda` is
    `data_weight` and :math:`N` is the number of pixels.

    Parameters
    ----------
    u : array_like
        2d array of floats storing the region mask.
    f : array_like
        The 2d image array with values between 0 and 1.
    averages : tuple, array_like
        The two averages of the pixel values in the foreground and background.
    data_weight: float
        The data weight in the energy.
    g : array_like
        2d array storing the values of the edge indicator weight used in TV term.

    Returns
    -------
    energy : float
        The value of the segmentation energy, divided by the number of pixels.
    """
    u = np.asarray(u, dtype=float)
    f = np.asarray(f, dtype=float)
    c1, c2 = averages

    # Forward differences of u; the last row / last column keep the value 0.
    Du = np.zeros((2,) + u.shape)
    Du[0, :-1, :] = u[1:, :] - u[:-1, :]
    Du[1, :, :-1] = u[:, 1:] - u[:, :-1]

    # The two fidelity terms are weighted by u and (1-u) respectively. The
    # (1-u) factor must not be nested inside the product with u, otherwise
    # the second term vanishes wherever u is exactly 0 or 1.
    data_term = np.sum( u * (f - c1)**2 + (1 - u) * (f - c2)**2 )
    regularization_term = np.sum( g * (np.abs(Du[0]) + np.abs(Du[1])) )
    energy = regularization_term + data_weight * data_term

    # Normalize by the actual number of pixels (rows * columns).
    return energy / u.size


#-----------------------------------------------------------------------
# The following function implements the convergence criterion described
# in Subsection 3.3, specifically, the equation (24).
#-----------------------------------------------------------------------

def check_convergence(u, energy, iter_count, history, n_terms_in_avg,
                      stopping_tol, force_max_iter):
    r"""Check if convergence criteria have been satisfied.

    The current `energy` is compared with the average of the most recent
    (at most `n_terms_in_avg`) energy values stored in ``history['energy']``.
    The difference is divided by the first energy value in the history, and
    convergence is declared when the absolute value of this relative change
    is below `stopping_tol`.

    No test is made, and ``(False, -inf)`` is returned, when
    `force_max_iter` is True, when `iter_count` is smaller than
    `n_terms_in_avg`, or when the history holds no energy value yet.

    Parameters
    ----------
    u : array_like
        The current region mask. It is not used by the criterion and is
        accepted only to keep the signature stable.
    energy : float
        The current energy value.
    iter_count : int
        The current count of iterations.
    history : dict
        Dictionary whose entry ``'energy'`` is the sequence of energy values
        of the previous iterations.
    n_terms_in_avg : int
        Number of energy values in history used to compute recent energy average.
    stopping_tol : float
        Tolerance to check if relative energy change is small enough for convergence.
    force_max_iter : bool
        If True, then continue iterating until max_iter, disregard criteria.

    Returns
    -------
    converged : bool
        True or False depending on whether the convergence criteria are satisfied.
    energy_change : float
        The relative change in energy compared to the recent moving average,
        or ``-inf`` when the test was skipped.
    """
    converged = False
    energy_change = -np.inf

    past_energies = history['energy']

    # Without past values there is nothing to average; skipping here avoids
    # taking the mean of an empty sequence (a warning and a nan result).
    if force_max_iter or iter_count < n_terms_in_avg or len(past_energies) == 0:
        return converged, energy_change

    init_energy = past_energies[0]
    recent_avg_energy = np.mean( past_energies[ -n_terms_in_avg: ] )
    energy_change = (energy - recent_avg_energy) / init_energy
    converged = (abs( energy_change ) < stopping_tol)

    return converged, energy_change


#---------------------------------------------------------------------
# The following function implements the initialization procedure
# described in Subsection 3.2, given by the equation (26).
#---------------------------------------------------------------------

def _initialize_u_f_g(u0, f0, edge_indicator_scale):
    r"""Initialize region function `u`, image `f`, edge indicator `g`.

    This function gets the region function initialization `u0` of the user,
    the input image `f`, and normalizes the values in these arrays to the
    range [0,1]. If `u0` is None, the `u_init` is set the image values.
    This function also computes the edge indicator function `g` used to
    weight the TV regularization term in the segmentation energy.

    Parameters
    ----------
    u0 : array_like
        The initialization of regions as provided by the user in 2d array.
    f0 : array_like
        The input image given as a 2d array.
    edge_indicator_scale : float
        The parameter used to scale the strength of the edges in the image.

    Returns
    -------
    u_init : array_like
        The normalized initialization for region function `u` with values in [0,1].
    f : array_like
        The image normalized to the range [0,1].
    g : array_like
        The edge indicator function compute from the image `f`.
    """

    if f0.ndim == 3: # if f is an RGB color image, make it grayscale.
        f = 0.2125 * f0[:,:,0] + 0.7154 * f0[:,:,1] + 0.0721 * f0[:,:,2]
    else:
        f = f0

    if u0 is None:
        min_f = f.min();  max_f = f.max()
        u_init = (np.double(f) - min_f) / (max_f - min_f)
    else:
        min_u0 = u0.min();  max_u0 = u0.max()
        u_init = (np.double(u0) - min_u0) / (max_u0 - min_u0)

    df = gaussian_gradient_magnitude(f,2) / edge_indicator_scale
    g = 1.0 / (1.0 + df**2)

    return u_init, f, g


#--------------------------------------------------------------------------
# The following is the main segmentation function implementing Algorithm 4.
#--------------------------------------------------------------------------

def segment_two_phase(f, u0=None, data_weight=0.1, stopping_tol=1e-3,
                      gamma=0.1, tau=0.01, edge_indicator_scale=0.1,
                      max_iter=1000, force_max_iter=False, print_iter_progress=0,
                      keep_all_history=False, return_all_details=False):
    r"""Two-phase segmentation of given image f using Bregman iterations.

    This function computes a two-phase segmentation of a given gray-scale
    image `f`. The segmentation is returned as a float-valued array `u`,
    with values in the range [0,1]. The locations in the array `u` with
    values `u` >= 0.5 define the first region :math:`\Omega_1` of the
    segmentation. The remaining locations, `u` < 0.5, define the second
    region :math:`\Omega_2`.

    This segmentation method is suitable for images whose pixel values can
    be approximated by two average values :math:`c_1` and :math:`c_2`
    for pixels in regions :math:`\Omega_1` and :math:`\Omega_2` respectively.
    The segmentation is obtained by computing the minimizer :math:`u^*` of
    a segmentation energy, in which the optimization variable is a region
    indicator function :math:`u(x)`, represented by the return array `u`.

    The segmentation energy consists of a data fidelity term and a regularization
    term, the weighted total variation (TV) :math:`\int g|\nabla u|` of `u`.
    The weight function :math:`g` in TV is the edge indicator function
    :math:`g(x) = 1 / (1 + |\nabla f|^2/s^2)` designed to stop regularization
    or smoothing of the region function :math:`u(x)` across edges in the image
    `f`. The parameter `edge_indicator_scale` is edge strength scale `s` in `g`
    to tune the magnitude of this effect.

    This function computes the minimizer `u` of a modified version of the
    two-phase segmentation energy using Bregman iterations. The parameter
    `gamma` is a model parameter used in the modified energy. Each iteration
    updates the region averages, takes one Jacobi sweep for `u` followed by
    a projection onto [0,1], solves the `d` subproblem and updates `b`.

    The user can choose to keep snapshots of `u` through the iterations and
    have them returned for documentation and diagnosis purposes. But this
    will increase the memory cost of the function significantly, as copies
    of `u` will be retained and stored in the return variable `history`.

    Parameters
    ----------
    f : array_like
        Array storing pixel values of the input image. Image should be grayscale.
    u0 : array_like, optional
        Initialization of `u` for the first iteration. If none, `f` is used
        as initializer.
    data_weight : float, optional
        Weight of the data term in the segmentation energy. Lower value of
        `data_weight` results in smoother region boundaries.
    stopping_tol : float, optional
        Stopping tolerance in the convergence criterion.
    gamma : float, optional
        Model parameter in the modified segmentation energy.
    tau : float, optional
        Step size parameter in the updates to the variable b.
    edge_indicator_scale : float, optional
        Parameter to scale the strength of the edge in the edge indicator
        function used as the weight in total variation.
    max_iter : int, optional
        The maximum number of iterations that the minimization algorithm can take.
    force_max_iter : bool, optional
        If set to True, then the algorithm ignores the convergence criterion,
        and takes `max_iter` number of iterations.
    print_iter_progress : int, optional
        An integer value indicating whether information about the progress
        of iterations should be printed periodically through iterations.
        If it is set to 0, then the information is not printed. If it is set
        to an integer value such as 20, then it is printed every 20 iterations.
    keep_all_history : bool or int, optional
        If set to True, a copy of the segmentation array `u` is stored in
        ``history['u']`` at every iteration. If it is set to a nonzero
        integer value, e.g. 5, then a copy is stored at iterations
        1, 6, 11, ... and None is stored as a placeholder for all other
        iterations.
    return_all_details : bool, optional
        If set to True, then all the additional details about the running
        of the algorithm are returned, i.e. the returned values are `u`,
        `averages`,`history`. Otherwise only `u` is returned.

    Returns
    -------
    u : array_like
        A float array of the same size as input image f with values in
        [0,1], storing the region indicator function of the last iteration.
    averages : tuple, optional
        A pair of float values that are the averages of image pixel values
        in the two regions, as used in the last iteration. This is returned
        if parameter `return_all_details` is True.
    history : dict, optional
        A dictionary storing the values of the intermediate values of key
        variables through the iterations and final information at convergence
        of the minimization algorithm. It is returned if `return_all_details`
        is true. The information stored in `history` includes elapsed time,
        iteration count, the final `u`, and lists of values of energy, energy
        change, change in u for all the iterations. If the parameter
        `keep_all_history` is set, then snapshots of intermediate
        segmentation arrays `u` are also stored in `history`.
    """
    u_prev, f, g = _initialize_u_f_g( u0, f, edge_indicator_scale )

    m, n = f.shape
    n_pixels = m * n
    d = np.zeros((2,m,n))
    # b_prev holds the current values of b; b is the buffer receiving the
    # next values. The two are swapped at the end of every iteration.
    b = np.zeros((2,m,n));  b_prev = np.zeros((2,m,n))

    n_terms_in_avg = 10

    history = { 'u': [], 'd': [], 'b': [], 's': [], 'r':[],
                'energy':[], 'energy change':[], 'u change':[],
                'n terms in avg':n_terms_in_avg }

    # A plain integer selects the snapshot interval; True means "every step".
    save_interval = 1
    if isinstance(keep_all_history, int) and not isinstance(keep_all_history, bool):
        save_interval = keep_all_history

    # If the loop below never runs, the initialization is the result.
    u = u_prev
    averages = None

    iter_count = 1
    converged = False
    start_time = time.time()

    while not converged and iter_count <= max_iter:

        r, regions, averages = _update_data_term( f, u_prev )

        s = _jacobi_update_for_u( u_prev, data_weight, gamma, r, d, b_prev )

        # Project the Jacobi update onto the interval [0,1].
        u = np.maximum( np.minimum(s, 1.0), 0.0 )

        # The d subproblem uses the current b, which lives in b_prev; the
        # buffer b still holds the values from two iterations ago.
        d = _subproblem_d( d, u, b_prev, g, gamma )

        b = _update_b( b, b_prev, u, d, tau )

        energy = compute_energy( u, f, averages, data_weight, g )

        converged, energy_change = \
                   check_convergence( u, energy, iter_count, history,
                                      n_terms_in_avg, stopping_tol, force_max_iter )

        history['energy'].append( energy )
        history['energy change'].append( energy_change )
        # Root-mean-square change of u over all pixels.
        history['u change'].append( norm(u - u_prev) / np.sqrt(n_pixels) )

        if (print_iter_progress > 0) and (iter_count % print_iter_progress == 0):
            print("k = %d/%d, E = %4.2e, delta_E / stop_tol = %4.2e / %4.2e )" %
                  (iter_count, max_iter, energy, energy_change, stopping_tol) )

        if keep_all_history:
            if (iter_count-1) % save_interval == 0:
                history['u'].append(u.copy())
            else:
                history['u'].append(None)

        iter_count = iter_count + 1

        # u is a freshly allocated array in every iteration, so it is enough
        # to remember it as the previous iterate; `u` itself keeps pointing
        # at the newest result, which is what is returned below.
        u_prev = u
        b_prev, b = b, b_prev

    if averages is None:
        # No iteration was taken (max_iter < 1): report the averages that
        # belong to the initialization.
        _, _, averages = _update_data_term( f, u )

    history['elapsed time'] = time.time() - start_time
    history['iteration count'] = iter_count - 1
    history['final u'] = u

    if print_iter_progress > 0:
        print("Segmentation completed in %d iterations, %4.2fs time." %
              (iter_count-1, history['elapsed time']))

    if not return_all_details:
        return u
    else: # return_all_details
        return u, averages, history


#-----------------------------------------------------------------------------
# The following is the main body, to enable running the segmentation from the
# command line: python segmentation.py --data_weight 0.1 inputfile outputfile
#-----------------------------------------------------------------------------

if __name__ == "__main__":

    # Command-line interface: two positional file names plus one option for
    # each tunable parameter of segment_two_phase.
    arg_parser = argparse.ArgumentParser()

    arg_parser.add_argument( "inputfile",  help="the input image filename, e.g. image.png")

    arg_parser.add_argument( "outputfile", help=("the output filename to save the binary "
                                                 "segmentation as an image, e.g. segmentation.png"))
    arg_parser.add_argument( "-d","--data_weight",type=float, default=0.1,
                      help=("the weight of the data fidelity term, a higher weight "
                            "results in a noisier segmentation, whereas a lower weight "
                            "results in a segmentation with smoother region boundaries, "
                            "default value is 0.1") )
    arg_parser.add_argument( "-s","--stopping_tol",   type=float,  default=0.001,
                      help=("stopping tolerance in the convergence criterion, "
                            "default value is 0.001") )

    arg_parser.add_argument( "-g","--gamma",          type=float,  default=0.1,
                      help=("model parameter in the modified segmentation energy, "
                            "default value is 0.1") )

    arg_parser.add_argument( "-t","--tau",            type=float,  default=0.01,
                      help="step size parameter with default value 0.01" )

    arg_parser.add_argument( "-e","--edge_indicator_scale", type=float, default=0.1,
                      help=("parameter to scale the strength of the edge in the edge "
                            "indicator function (used as the weight in total variation "
                            "term), default value is 0.1") )

    arg_parser.add_argument( "-m","--max_iter",       type=int,  default=1000,
                      help=("the maximum number of iterations that the minimization "
                            "algorithm can take, default value is 1000") )
    # The flag is given as 0 or 1 on the command line and converted below.
    arg_parser.add_argument( "-f","--force_max_iter", type=int,  default=0,  choices=[0,1],
                      help=("boolean flag, if set to True, then the maximum number of "
                            "iterations are taken regardless of the stopping criterion, "
                            "default value is False") )
    arg_parser.add_argument( "-p","--print_iter_progress", type=int, default=0,
                      help=("integer value indicating whether information about the progress "
                            "of iterations should be printed periodically through iterations, "
                            "if set to 0, then the information is not printed, if set to an "
                            "integer value, such as 20, then it is printed every 20 iterations; "
                            "its default value is 0") )

    args = arg_parser.parse_args()

    f = imread( args.inputfile )

    u = segment_two_phase( f, u0=None,
                           data_weight=args.data_weight,
                           stopping_tol=args.stopping_tol,
                           gamma=args.gamma, tau=args.tau,
                           edge_indicator_scale=args.edge_indicator_scale,
                           max_iter=args.max_iter,
                           force_max_iter=(args.force_max_iter == 1),
                           print_iter_progress=args.print_iter_progress,
                           keep_all_history=False, return_all_details=False)

    # Binarize with the same rule that defines the first region inside the
    # algorithm (u >= 0.5), so that pixels with u exactly 0.5 are assigned
    # consistently.
    segmentation = u >= 0.5

    imsave( args.outputfile, segmentation, cmap='gray' )
