
# Copyright (c) 2024, Olakunle Abawonse, Gunay Dogan
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for more
# details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


import os

import matplotlib.pyplot as plt
from matplotlib.pyplot import subplot, imshow, plot, gray, axis, title
from numpy import arange, array
from segmentation import segment_two_phase

# Destination folder for the saved figures. The default empty string saves
# them into the current working directory.
FIG_FOLDER = ""

def segment_image(filename, data_weight):
    """Read an image file and run the two-phase segmentation on it.

    The loaded image is used both as the data term and as the initial guess
    ``u0``.  After the run, the iteration count, the elapsed time and the last
    entry of ``history['tau']`` are printed.

    Parameters
    ----------
    filename : str
        Path of the image, read with ``plt.imread``.
    data_weight : float
        Passed unchanged as ``data_weight`` to ``segment_two_phase``.

    Returns
    -------
    tuple
        ``(f, u, averages, history)``: the loaded image, followed by the
        three values ``segment_two_phase`` returns when called with
        ``return_all_details=True``.
    """
    print("\nSegmenting image = %s\n" % filename)

    image = plt.imread(filename)

    # Solver settings shared by every experiment in this script.
    solver_options = dict(
        stopping_tol=1e-4,
        gamma=0.1,
        tau=0.01,
        edge_indicator_scale=0.1,
        max_iter=1000,
        force_max_iter=False,
        print_iter_progress=25,
        keep_all_history=False,
        return_all_details=True,
    )
    segmentation, phase_averages, details = segment_two_phase(
        image, u0=image, data_weight=data_weight, **solver_options
    )

    summary = (details['iteration count'],
               details['elapsed time'],
               details['tau'][-1])
    print("\nTotal # iterations: %d,  elapsed time: %5.3f,  final tau: %6.4f.\n" %
          summary)

    return image, segmentation, phase_averages, details


def plot_segmentation(f=None, u=None, data_weight=0.1):
    """Display an image next to its two-phase segmentation.

    Parameters
    ----------
    f : array or str, optional
        The input image, or a path to an image file that is read with
        ``plt.imread``.  Defaults to the file "galaxy.png".
    u : array, optional
        A precomputed segmentation of ``f``.  When omitted it is computed
        here with ``segment_two_phase`` (using ``f`` as the initial guess).
    data_weight : float
        Passed as ``data_weight`` to ``segment_two_phase`` when ``u`` is
        computed here; ignored otherwise.

    Returns
    -------
    tuple
        ``(f, u)``: the image array and the segmentation that were shown.
    """
    if f is None:
        f = "galaxy.png"
    # The default is itself a path, so it must be loaded exactly like a
    # user-supplied path; otherwise a bare filename string would be handed
    # to the segmentation and to imshow.
    if isinstance(f, str):
        f = plt.imread(f)

    if u is None:
        u = segment_two_phase( f, u0=f, data_weight=data_weight, stopping_tol=1e-4,
                               gamma=0.1, tau=0.01, edge_indicator_scale=0.1,
                               max_iter=1000, force_max_iter=False, print_iter_progress=25,
                               keep_all_history=False, return_all_details=False )

    plt.figure()
    # Left panel: input image; right panel: segmentation.
    for position, panel in enumerate((f, u), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(panel)
        plt.gray()
        plt.axis('off')
    plt.show()
    return f, u


def figure_2_snapshots(image, snapshots, data_weight=0.1):
    """Save thresholded snapshots of the segmentation's evolution.

    Segments ``<image>.png`` while keeping the full iteration history.  For
    each iteration number ``k`` in ``snapshots``, the iterate
    ``history['u'][k - 1]`` is thresholded at 0.5.  It is then saved as
    ``<image>_iter_<k>.png`` and drawn as one panel of
    ``<image>_snapshots.png``, both inside ``FIG_FOLDER``.

    Parameters
    ----------
    image : str
        Base name of the image file (without ".png"); also the prefix of the
        saved figure names.
    snapshots : sequence of int
        Iteration numbers to capture, counted from 1.
    data_weight : float
        Passed as ``data_weight`` to ``segment_two_phase``.

    Raises
    ------
    ValueError
        If any snapshot number is less than 1.  Such a value would otherwise
        silently select an iterate from the *end* of the history through
        negative indexing.
    """
    print("\n>>>> Producing snapshots from evolution of segmentation for %s image <<<<\n" % image)

    # Validate before the (expensive) segmentation run.
    bad = [s for s in snapshots if s < 1]
    if bad:
        raise ValueError("snapshot iteration numbers must be >= 1, got %s" % bad)

    f = plt.imread( image + ".png" )

    _, _, history = \
       segment_two_phase( f, u0=f, data_weight=data_weight, stopping_tol=1e-4,
                          gamma=0.1, tau=0.01, edge_indicator_scale=0.1,
                          max_iter=1000, force_max_iter=False, print_iter_progress=25,
                          keep_all_history=True, return_all_details=True )

    # Threshold each requested iterate once and reuse it for both outputs.
    frames = [history['u'][snapshot - 1] > 0.5 for snapshot in snapshots]

    # Individual snapshot files; the colormap is given explicitly rather than
    # relying on a global gray() call made on a throwaway figure.
    for snapshot, frame in zip(snapshots, frames):
        filename = os.path.join(FIG_FOLDER, image + "_iter_" + str(snapshot) + ".png")
        plt.imsave(filename, frame, cmap='gray')

    # Combined figure with one panel per snapshot.
    plt.figure();  gray()
    n = len(snapshots)
    for k, (snapshot, frame) in enumerate(zip(snapshots, frames)):
        subplot(1, n, k + 1)
        imshow(frame)
        gray()
        axis('off')
        title('iter # %d' % snapshot)
    plt.savefig(os.path.join(FIG_FOLDER, image + "_snapshots.png"))
    plt.close()


def figure_3_energy_plots(image, data_weights):
    """Plot the recorded energy of the segmentation for several data weights.

    For every weight in ``data_weights``, ``<image>.png`` is segmented and
    ``history['energy']`` is saved as its own plot,
    ``<image>_energy_for_lambda_<weight>.png``.  All curves are then saved
    side by side in ``<image>_energy_plots.png``.  Both files go to
    ``FIG_FOLDER``.

    Parameters
    ----------
    image : str
        Base name of the image file (without ".png"); also the prefix of the
        saved figure names.
    data_weights : sequence of float
        Values passed as ``data_weight`` to ``segment_two_phase``; shown as
        lambda in the plot titles.
    """
    print("\n>>>> Plots of energy through iterations for %s image <<<<" % image)
    f = plt.imread( image + ".png" )

    energy_list = []
    for weight in data_weights:
        print("\n  === Data weight = %5.3f ===\n" % weight)
        _, _, history = segment_two_phase( f, u0=f, data_weight=weight, stopping_tol=1e-4,
                            gamma=0.1, tau=0.01, edge_indicator_scale=0.1,
                            max_iter=1000, force_max_iter=False, print_iter_progress=25,
                            keep_all_history=False, return_all_details=True )
        energy = history['energy']
        energy_list.append(energy)

        plt.figure()
        plt.plot(energy)
        # Raw string: "\l" is not a valid Python escape, and a non-raw literal
        # containing it emits a warning (SyntaxWarning since Python 3.12).
        plt.title(r"Energy for %s, $\lambda=%4.2f$" % (image, weight))
        plt.savefig(os.path.join(FIG_FOLDER,
                                 image + "_energy_for_lambda_" + str(weight) + ".png"))
        plt.close()

    # Side-by-side comparison of all energy curves.
    n = len(data_weights)
    plt.figure()
    for k, (weight, energy) in enumerate(zip(data_weights, energy_list)):
        subplot(1, n, k + 1)
        plot(energy)
        title(r'Energy for $\lambda = %3.1f$' % weight)
    plt.savefig(os.path.join(FIG_FOLDER, image + "_energy_plots.png"))
    plt.close()


def figure_4_regularization(image, data_weights, labels):
    """Compare segmentations of one image under different data weights.

    Each weight in ``data_weights`` produces a segmentation of
    ``<image>.png``.  Each result is saved as ``<image>_lambda_<weight>.png``,
    and all of them are drawn side by side in ``<image>_regularization.png``.
    Both go to ``FIG_FOLDER``.

    Parameters
    ----------
    image : str
        Base name of the image file (without ".png"); also the prefix of the
        saved figure names.
    data_weights : sequence of float
        Values passed as ``data_weight`` to ``segment_two_phase``.
    labels : sequence of str
        One descriptive label per weight, used in the panel titles (e.g.
        "low", "medium", "high").
    """
    print("\n>>>> Effects of varying regularization on segmentation of %s image <<<<\n" % image)
    f = plt.imread( image + ".png" )

    u_list = []
    for weight in data_weights:
        u = segment_two_phase( f, u0=f, data_weight=weight, stopping_tol=1e-4,
                               gamma=0.1, tau=0.01, edge_indicator_scale=0.1,
                               max_iter=1000, force_max_iter=False, print_iter_progress=25,
                               keep_all_history=False, return_all_details=False )
        u_list.append(u)

    # Individual result files; the colormap is given explicitly instead of
    # opening a blank figure only to call gray() for its global side effect.
    for u, weight in zip(u_list, data_weights):
        filename = os.path.join(FIG_FOLDER, image + "_lambda_" + str(weight) + ".png")
        plt.imsave(filename, u, cmap='gray')

    # Combined figure with one panel per data weight.
    plt.figure(); gray()
    n = len(data_weights)
    for k, (u, weight, label) in enumerate(zip(u_list, data_weights, labels)):
        subplot(1, n, k + 1)
        imshow(u)
        gray()
        axis('off')
        # Raw string: "\l" is not a valid Python escape sequence.
        title(r'%s regularization $\lambda = %3.1f$' % (label, weight))
    plt.savefig(os.path.join(FIG_FOLDER, image + "_regularization.png"))
    plt.close()


######################################################################

if __name__ == "__main__":

    # Usage note: a single image can be segmented on its own with
    #     segment_image("galaxy.png", data_weight=0.1)

    # One row per test image: (base name, iteration numbers to snapshot,
    # data weights swept in the energy and regularization figures).
    experiments = [
        ("galaxy",         [5, 20, 40, 113],   [5.0, 2.0, 1.0]),
        ("microstructure", [5, 100, 300, 853], [2.5, 0.5, 0.1]),
        ("pearlite",       [3, 100, 200, 434], [5.0, 1.5, 0.1]),
    ]
    regularization_labels = ["low", "medium", "high"]

    # Each figure type is produced for every image before moving on to the
    # next figure type.
    for name, snapshot_iters, _ in experiments:
        figure_2_snapshots(name, snapshot_iters)

    for name, _, weights in experiments:
        figure_3_energy_plots(name, weights)

    for name, _, weights in experiments:
        figure_4_regularization(name, weights, regularization_labels)
