#!/usr/bin/env python3
''' <This program runs the IPOL demo "Thin-Plate Splines on the Sphere.">
    Copyright (C) 2022 Max Dunitz

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <https://www.gnu.org/licenses/>.
'''
import warnings, os, argparse
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.special import  xlogy
import matplotlib.image as mpimg
from scipy.special import zeta

## LINEAR ALGEBRA HELPER

def pinv_and_cond(a, rcond=1e-15):
    '''function that returns the Moore-Penrose pseudoinverse of a matrix together with its
    2-norm condition number, both obtained from a single SVD.
    a - the matrix (M x N, or a stack of matrices of shape (..., M, N))
    rcond - relative singular value cutoff, exactly as in np.linalg.pinv: singular values
            not larger than rcond * (largest singular value) are treated as zero
    returns (pinv, cond) where pinv corresponds to np.linalg.pinv(a, rcond=rcond) and cond is
            the ratio of the largest to the smallest singular value of a (inf when the
            smallest singular value is exactly zero, nan when every singular value is zero)
    '''

    # reduced SVD: enough for the pseudoinverse and keeps the shapes compatible for M != N
    u, s, vt = np.linalg.svd(a, full_matrices=False)
    # the condition number has to be read off the singular values themselves,
    # i.e. before they are inverted for the pseudoinverse
    with np.errstate(divide='ignore', invalid='ignore'):
        cond = s[..., 0] / s[..., -1]
    cutoff = np.asarray(rcond)[..., np.newaxis] * np.amax(s, axis=-1, keepdims=True)
    large = s > cutoff
    # invert only the singular values above the cutoff; the others stay at zero
    s_inv = np.zeros_like(s)
    np.divide(1, s, where=large, out=s_inv)
    u_h = np.swapaxes(u, -1, -2).conj()
    v = np.swapaxes(vt, -1, -2).conj()
    pinv = v @ (s_inv[..., np.newaxis] * u_h)
    return pinv, cond

## KERNEL METHODS

def natural_cubic(x1sample, x2sample, l=0, training=True):
    '''builds the block matrix [[K, T], [T', 0]] of the linear system in equation (37)
       for the natural cubic spline (order m=2) on the index set X=[-1,1], as in algorithm 4
    x1sample - 1-D array of points indexing the rows of the kernel block
    x2sample - 1-D array of points indexing the columns of the kernel block
    l - smoothing parameter added to the diagonal of the kernel block (only when training)
    training - if True, x1sample and x2sample must have equal length and l*I is added
    returns an array of shape (len(x1sample)+2, len(x2sample)+2)
    '''
    n_rows, n_cols = len(x1sample), len(x2sample)
    a, b = np.meshgrid(x1sample, x2sample, indexing='ij')

    # kernel k^1 of H_1 (section 2.6.1), evaluated elementwise; defined in equation 43,
    # explicit form given in remark 2.56
    lo = np.minimum(a, b)
    gram = a*b*(lo+1) - 0.5*(a+b)*(lo**2-1) + 1/3.0*(lo**3+1)
    if training:
        # the smoothing term only makes sense for a square Gram matrix
        assert n_rows == n_cols
        gram = gram + l*np.eye(n_rows)

    # nullspace part ("T" in (37)): basis {1, 1+x} evaluated at the samples (remark 2.56)
    right = np.column_stack((np.ones(n_rows), 1 + x1sample))
    bottom = np.vstack((np.ones(n_cols), 1 + x2sample))
    return np.block([[gram, right], [bottom, np.zeros((2, 2))]])

def tps(x1sample, y1sample, x2sample, y2sample, l=0, training=True):
    '''
    builds the block matrix [[M, T], [T', 0]] of the linear system in equation (37) for the
    thin-plate spline of order 2 on R^2 (see algorithm 7, with d=2, m=2)
    this is the "classic" TPS energy whose null space is the affine functions span{1, x, y}
    the kernel block is evaluated as xlogy(r, r), with r the squared Euclidean distance
    between a point of the first sample and a point of the second sample
    x1sample, y1sample - 1-D arrays with the coordinates of the points indexing the rows
    x2sample, y2sample - 1-D arrays with the coordinates of the points indexing the columns
    l - smoothing parameter added to the diagonal of M (only when training)
    training - if True, both samples must have equal length and l*I is added to M
    returns an array of shape (len(x1sample)+3, len(x2sample)+3)
    '''
    n1, n2 = len(x1sample), len(x2sample)

    # matrix of "representers of evaluation"
    xa, xb = np.meshgrid(x1sample, x2sample, indexing='ij')
    ya, yb = np.meshgrid(y1sample, y2sample, indexing='ij')
    sqdist = (xb - xa)**2 + (yb - ya)**2
    gram = xlogy(sqdist, sqdist)
    if training:
        # the smoothing term only makes sense for a square Gram matrix
        assert n1 == n2
        gram = gram + l*np.eye(n1)

    # nullspace part ("T" of (37)): 1, x, y evaluated at each sample
    right = np.concatenate(
        (np.ones((n1, 1)), x1sample.reshape((n1, 1)), y1sample.reshape((n1, 1))), axis=1)
    bottom = np.concatenate(
        (np.ones((1, n2)), x2sample.reshape((1, n2)), y2sample.reshape((1, n2))), axis=0)
    return np.block([[gram, right], [bottom, np.zeros((3, 3))]])

def get_K_tps_sphere(theta1sample, phi1sample, theta2sample, phi2sample, plog2, plog3, order=2, l=1, training=True):
    # null space of Laplace-Beltrami operator on the sphere is span{1} 
    '''
    computes the matrix on the left-hand side of the linear system (37) for the spherical thin-plate splines
    only works for sphere S^2 in R^3 (see algorithm 8), kernel defined on "Fourier side"; Wendelberger
    computed the kernel sums k_{3,2} and k_{3,3} (see section 3.1)
    theta1sample, phi1sample - 1-D arrays of colatitude / azimuth (radians) of the points indexing the rows
    theta2sample, phi2sample - 1-D arrays of colatitude / azimuth (radians) of the points indexing the columns
    plog2, plog3 - lookup arrays indexed through get_idx (presumably tabulated polylogarithms of
                   order 2 and 3 on a grid of step 1e-5 starting at -1 -- confirm against the
                   script that generated them)
    order - spline order, 2 or 3
    l - smoothing parameter added to the diagonal of the kernel block (only when training)
    training - if True, both samples must have equal length and the diagonal of the kernel block
               is overwritten with (kernel value at zero angle) + l
    returns an array of shape (len(theta1sample)+1, len(theta2sample)+1)
    raises ValueError if order is neither 2 nor 3
    '''
    if order not in (2, 3):
        # without this check the kernel would be left undefined and np.block would fail obscurely
        raise ValueError("order must be 2 or 3, got %r" % (order,))
    # compute cosine of angle between each pair of points, i.e. x^ty for x,y in R^3 with ||x||_{R^3} = ||y||_{R^3} = 1
    theta1s, theta2s = np.meshgrid(theta1sample, theta2sample, indexing='ij')
    phi1s, phi2s = np.meshgrid(phi1sample, phi2sample, indexing='ij')
    cosgamma = clip(np.cos(theta1s)*np.cos(theta2s) + np.sin(theta1s)*np.sin(theta2s)*np.cos(phi2s-phi1s), -1, 1) # elementwise cosine of angle between (theta1, phi1) and (theta2, phi2)
    u = (1-cosgamma)/2.0 # argument to polylog
    omu = 1-u # other argument to polylog 
    idxs = get_idx(u) # indices in the plog3 or plog2 arrays of u
    idxs_omu = get_idx(omu) # indices in the plog3 or plog2 arrays of 1-u 
    # compute kernel
    if order == 3:
        # need to handle cosgamma = 1 separately; see section 3.1
        coincident = np.abs(u) < 1e-15
        # np.log(0) is evaluated at coincident points, but np.where discards those entries,
        # so the floating-point warnings it would raise are suppressed
        with np.errstate(divide='ignore', invalid='ignore'):
            generic = -2*plog3[idxs] - plog2[idxs_omu] + np.log(u)*plog2[idxs] + 2*zeta(3) + np.pi**2/6.0 - 2
        K = (4*np.pi)**(-1)*np.where(coincident, 2*(zeta(3)-1), generic)
        diag_value = (2*np.pi)**(-1)*(zeta(3)-1)
    else:
        K = (4*np.pi)**(-1)*(1 - np.pi**2/6 + plog2[idxs_omu])
        diag_value = (4*np.pi)**(-1)
    if training:
        assert len(theta1sample) == len(theta2sample)
        np.fill_diagonal(K, diag_value+l) # regularize
    # since nullspace is {1}, the T matrix is a vector of 1s
    N1 = len(theta1sample)
    N2 = len(theta2sample)
    T = np.ones((N1,1))
    Tt = np.ones((1,N2))
    return np.block([[K, T], [Tt, np.zeros((1,1))]])

## OTHER HELPER FUNCTIONS

def clip(x, lower=0, upper=1):
    '''function that forces the entries of the array x into [lower, upper]
       x is modified in place and the same array object is returned
       used to counteract rounding error that could push values outside the domain of arccos
    '''
    # raise the entries below the lower bound first, then cap the ones above the upper bound
    np.putmask(x, x < lower, lower)
    np.putmask(x, x > upper, upper)
    return x

def display_to_math(lat, lon):
    '''
    converts display coordinates (degrees) to the angles used in the computations (radians)
    lat - latitude in [-90, 90] degrees, mapped to colatitude pi/2 - lat in [0, pi]
    lon - longitude in [-180, 180] degrees, mapped to azimuthal angle pi + lon in [0, 2*pi]
    returns (colatitude, azimuthal angle)
    '''
    lat_rad = lat*np.pi/180
    lon_rad = lon*np.pi/180
    return np.pi/2 - lat_rad, np.pi + lon_rad

def math_to_display(colat,lon):
    '''
    converts the angles used in the computations (radians) back to display coordinates (degrees)
    colat - colatitude in [0, pi] radians, mapped to latitude 90 - colat in [-90, 90] degrees
    lon - azimuthal angle in [0, 2*pi] radians, mapped to longitude lon - 180 in [-180, 180] degrees
    returns (latitude, longitude)
    '''
    colat_deg = colat*180/np.pi
    lon_deg = lon*180/np.pi
    return 90 - colat_deg, -180 + lon_deg

def get_idx(cosgamma, step=1e-5):
    '''maps numbers in [-1,1] to integer positions on a uniform grid that starts at -1 with
       spacing step; used to look up polylog evaluations in the plog2 and plog3 arrays
    '''
    # distance from the left end of the grid, rounded to the nearest grid point
    offset = cosgamma + 1
    return np.rint(offset/step).astype('int')

if __name__ == '__main__':
    # Demo driver: loads scattered observations on the sphere (first .csv in the working
    # directory, otherwise points sampled from the first .png), fits a natural cubic spline,
    # a spherical thin-plate spline and a planar thin-plate spline, and writes the PNG figures.

    # PARAMETERS GOVERNING SIMULATION CAPACITY     
    MAX_PTS = 250
    OUTPUT_GRID_LAT = 200
    OUTPUT_GRID_LON = 400

    # MATPLOTLIB params
    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)

    # LOAD INPUT
    files_csv = [f for f in os.listdir('.') if os.path.isfile(f) and f[-4:] == '.csv']
    if len(files_csv) > 0: # use csv if available
        df = pd.read_csv(files_csv[0])
        X = np.zeros((len(df), 3))
        X[:,0] = df['latitudes']
        X[:,1] = df['longitudes']
        X[:,2] = df['observations']
        n = len(X)
        rows = np.arange(n)
        if n > MAX_PTS: # keep a random subset of MAX_PTS observations
            np.random.shuffle(rows)
            rows = rows[:MAX_PTS]
            n = MAX_PTS
    else: # else use image
        files_png = [f for f in os.listdir('.') if os.path.isfile(f) and f[-4:] == '.png']
        if len(files_png) == 0:
            with open("demo_failure.txt", "w") as f:
                f.write("No input selected! Please re-run after selecting an input.")
            # same exit status as before, without relying on the site-provided exit() helper
            raise SystemExit(0)
        img = mpimg.imread(files_png[0])
        if len(img.shape) == 3: # average the channels to get a single-valued image
            img = np.mean(img, axis=2)
        M, N = img.shape
        n = np.minimum(MAX_PTS, M*N)
        # to get random points on a sphere, get samples from a spherical Gaussian in R^3 and project onto sphere, then find which pixel bins of the planisphere image these fall in
        pts_ = np.random.normal(size=(int(1.1*MAX_PTS),3))
        pts_ = pts_/np.tile(np.linalg.norm(pts_, axis=1).reshape((-1,1)), (1,3))
        lons_pts_, lats_pts_ = np.arctan2(pts_[:,1], pts_[:,0])*180/np.pi, 90-np.arccos(pts_[:,2]/np.sqrt(pts_[:,0]**2+pts_[:,1]**2+pts_[:,2]**2))*180/np.pi
        valid = np.logical_not(np.logical_or(np.isnan(lats_pts_), np.isnan(lons_pts_)))
        pixel_lats = np.linspace(-90,90,M+1)[:-1]
        pixel_lons = np.linspace(-180,180,N+1)[:-1]
        lats_idxs = np.searchsorted(pixel_lats[1:], lats_pts_[valid])
        lons_idxs = np.searchsorted(pixel_lons[1:], lons_pts_[valid])
        vals = img[(M-1)-lats_idxs, lons_idxs] # image row 0 is the northernmost latitude bin
        # observation locations are the centres of the pixel bins
        X = np.hstack((pixel_lats[lats_idxs].reshape((-1,1)) + 90/M, pixel_lons[lons_idxs].reshape((-1,1)) + 180/N, vals.reshape((-1,1))))
        rows = np.arange(n)
    lat_train_disp, lon_train_disp = X[rows,0], X[rows,1]
    lat_train_math, lon_train_math = display_to_math(lat_train_disp, lon_train_disp)
    pts_display = X[rows,:2]
    pts_math = np.hstack([lat_train_math.reshape((n,1)), lon_train_math.reshape((n,1))])
    T = X[rows,2] # observed values

    # PARSE ARGUMENTS

    parser = argparse.ArgumentParser(description="Reads  the file 'in.csv', which is a three-column data set with latitude values (-90-90 degrees) in the first column, longitude (-180 to 180 degrees) in the second, and observed data in the third. Fits a thin-plate spherical spline model. Interpolation is exact when penalty is 0, otherwise the model minimizes the empirical risk with a smoothness penalty (penalty > 0).")
    parser.add_argument('-p','--penalty', help='Penalty value (float), spherical spline', required=True)
    parser.add_argument('-pc', '--penalty_cubic', help='Penalty value (float), natural cubic spline', required=True)
    parser.add_argument('-ptps', '--penalty_tps', help='Penalty value (float), Euclidean thin-plate spline', required=True)
    parser.add_argument('-o', '--order', help='Spline order (either 2 or 3)', required=True)
    parser.add_argument('-mila', '--minimum_latitude', help='Minimum latitude of output planisphere image', required=True)
    parser.add_argument('-mala', '--maximum_latitude', help='Maximum latitude of output planisphere image', required=True)
    parser.add_argument('-milo', '--minimum_longitude', help='Minimum longitude of output planisphere image', required=True)
    parser.add_argument('-malo', '--maximum_longitude', help='Maximum longitude of output planisphere image', required=True)
    parser.add_argument('-loc', '--location', help="Location of demoextras", default='..', required=False)
    args = vars(parser.parse_args())
    # the penalties are given as base-10 exponents
    l = 10**(float(args['penalty']))*n # multiply lambda by n to get regularization parameter l=lambda*n
    l_cubic = 10**(float(args['penalty_cubic']))*n
    l_tps = 10**(float(args['penalty_tps']))*n
    order = int(args['order'])
    min_lon = float(args['minimum_longitude'])
    max_lon = float(args['maximum_longitude'])
    min_lat = float(args['minimum_latitude'])
    max_lat = float(args['maximum_latitude'])
    loc = args['location']

    # LOAD LOOKUP TABLES
    plog3 = np.load(os.path.join(loc, "plog3.npy"))
    plog2 = np.load(os.path.join(loc, "plog2.npy"))

    # PARAMETERS FOR PLOTTING
    dlon = np.abs(max_lon-min_lon)
    dlat = np.abs(max_lat-min_lat)
    marg = 0.01
    ext = [min_lon-dlon*marg, max_lon+dlon*marg, min_lat-dlat*marg, max_lat+dlat*marg]
    
    # CREATE OUTPUT GRID
    M, N = OUTPUT_GRID_LAT, OUTPUT_GRID_LON
    nrows, ncols = (M, N)
    lat_display, lon_display = np.meshgrid(np.linspace(min_lat, max_lat, nrows), np.linspace(min_lon, max_lon, ncols), indexing='ij')
    lat_math, lon_math = display_to_math(lat_display, lon_display)

    latlong_te_math = np.zeros((M*N,2))
    latlong_te_disp = np.zeros((M*N,2))
    latlong_te_math[:,0] = lat_math.flatten()
    latlong_te_math[:,1] = lon_math.flatten()
    latlong_te_disp[:,0] = lat_display.flatten()
    latlong_te_disp[:,1] = lon_display.flatten()

    ## SCATTER PLOT OF DATA, NATURAL CUBIC SPLINE INTERPOLATION, AND MEAN COMPUTED FROM NATURAL CUBIC SPLINE
    xs_plot = np.linspace(-1,1,1000)
    plt.clf()
    ax = plt.axes()
    sin_train_lat = np.sin(np.pi/2-lat_train_math) # plot data in sin theta
    ax.scatter(sin_train_lat, T, marker='.', label=r'Observations')
    # solve equation (37) for the natural cubic splines to recover alphas 
    K_lin = natural_cubic(sin_train_lat, sin_train_lat, l=l_cubic, training=True)
    alphas = np.linalg.pinv(K_lin) @ np.array(list(T) + [0,0])
    # evaluate the interpolant on xs_plot
    evals = (natural_cubic(xs_plot, sin_train_lat, training=False)@alphas)[:-2]
    # compute the mean value (see equation 62 in section 5.2)
    mean = (alphas[-1] + alphas[-2]) + 0.5*(sin_train_lat**4/24 - sin_train_lat**3/6 + sin_train_lat**2/4 + 7*sin_train_lat/6 + 17/24)@alphas[:-2]
    print("AREA:", mean, np.mean(evals))
    ax.plot(xs_plot, evals, label=r"Natural Cubic Spline, $\lambda=$"+f"{l_cubic:.3e}")
    ax.set_xlabel('sin(latitude)')
    ax.set_ylabel(r'Observation Values')
    textstr = r'mean$=%.3f$' % (mean, )
    txt = ax.text(0.04, 0.96, textstr, transform=ax.transAxes, fontsize=10, verticalalignment='top', bbox=props)
    txt.set_bbox({'alpha': 0.05})
    plt.title(r"Natural Cubic Spline Fit (for Isotropic Data)")
    plt.legend(loc=1, fancybox=True, framealpha=0.5)
    plt.savefig('naturalcubicspline.png')
    plt.close() # release the figure, as is done for every other plot below

    # PLANISPHERE PLOT OF SCATTERED DATA WITH SAMPLE MEAN
    marg = 0.01 # plotting param
    lower_, upper_ = np.min(T)-marg*np.abs(np.min(T)), np.max(T)+marg*np.abs(np.max(T)) # plotting limits
    fig, ax = plt.subplots()
    plt.scatter(pts_display[:,1], pts_display[:,0], c=T, cmap='coolwarm', vmin=lower_, vmax=upper_, s=10)
    ax.set_xlabel(r"Longitude (degrees)")
    ax.set_ylabel(r"Latitude (degrees)")
    plt.title(f"Scattered Data")
    txt = ax.text(0.04, 0.96, f"Sample mean: {np.mean(T):.3f}", transform=ax.transAxes, fontsize=10, verticalalignment='top', bbox=props)
    txt.set_bbox({'alpha': 0.05})
    plt.colorbar()
    plt.savefig("data.png", bbox_inches='tight')
    plt.close() 

    # THIN-PLATE SPLINE ON THE SPHERE
    K_sphere = get_K_tps_sphere(pts_math[:,0], pts_math[:,1], pts_math[:,0], pts_math[:,1], plog2=plog2, plog3=plog3, order=order, l=l, training=True)
    K_sphere_inv, cond = pinv_and_cond(K_sphere)
    if cond > 1e10:
        warnings.warn("TPS spherical badly conditioned, consider increasing the smoothing penalty") 
    alphas = K_sphere_inv @ np.array(list(T) + [0])
    # evaluate the fitted spline on the output grid; the last row of the product belongs to the
    # null-space constraint and is dropped
    predictions_sphere = get_K_tps_sphere(latlong_te_math[:,0], latlong_te_math[:,1], pts_math[:,0], pts_math[:,1], plog2=plog2, plog3=plog3, order=order, training=False) @ alphas
    predictions_sphere = predictions_sphere[:-1].flatten()
    
    # plot interpolating surface with spherical mean
    fig, ax = plt.subplots()
    # rows are flipped because the grid latitudes increase with the row index while imshow draws row 0 at the top
    im = plt.imshow(predictions_sphere.reshape((M,N))[::-1,:], cmap='coolwarm', vmin=lower_, vmax=upper_, extent=ext) 
    ax.set_xlabel(r"Longitude (degrees)")
    ax.set_ylabel(r"Latitude (degrees)")
    ax.set_title(r"Interpolating Surface, TPS (spherical, $\lambda=$"+f"{l:.3e})")
    # the coefficient of the constant null-space function is reported as the spherical mean
    textstr = r'spherical mean$=%.3f$' % (alphas[-1], )
    txt = ax.text(0.04, 0.96, textstr, transform=ax.transAxes, fontsize=10, verticalalignment='top', bbox=props)
    txt.set_bbox({'alpha': 0.05})
    cax = fig.add_axes([ax.get_position().x1+0.03,ax.get_position().y0,0.02,ax.get_position().height])
    plt.colorbar(im, cax=cax)
    plt.savefig("interp_sphere.png", bbox_inches='tight')
    plt.close()

    # see how well the sphere recovers the data it was trained on 

    recovered_sphere = get_K_tps_sphere(pts_math[:,0], pts_math[:,1], pts_math[:,0], pts_math[:,1], plog2=plog2, plog3=plog3, order=order, training=False) @ alphas
    recovered_sphere = recovered_sphere[:-1]
    
    plt.figure()
    plt.scatter(pts_display[:,1], pts_display[:,0], c=np.abs(T-recovered_sphere), cmap='coolwarm',  s=10)
    plt.xlabel(r"Longitude (degrees)")
    plt.ylabel(r"Latitude (degrees)")
    plt.title(r"Reconstruction Error, TPS (spherical, $\lambda=$"+f"{l:.3e})")
    plt.colorbar()
    plt.savefig("reconstruction_err_sphere.png", bbox_inches="tight")
    plt.close()

    ## PLANAR THIN-PLATE SPLINES

    # the planar spline treats (colatitude, azimuth) in radians as Cartesian coordinates
    K_plane = tps(pts_math[:,0], pts_math[:,1], pts_math[:,0], pts_math[:,1], l=l_tps, training=True)
    K_plane_inv, cond = pinv_and_cond(K_plane)
    if cond > 1e10:
        warnings.warn("TPS planar badly conditioned, consider increasing the smoothing penalty")
    # reuse the pseudoinverse computed above instead of decomposing K_plane a second time
    alphas = K_plane_inv @ np.array(list(T) + [0,0,0])
    predictions_plane = tps(latlong_te_math[:,0], latlong_te_math[:,1], pts_math[:,0], pts_math[:,1], training=False) @ alphas
    predictions_plane = predictions_plane[:-3].flatten()

    # compute area under curve using equation 61 in section 5.1
    # int_vals is a precomputed table looked up at the training coordinates rounded to 1e-3 rad
    int_vals = np.load(os.path.join(loc,'int_vals.npy'))
    coords = np.around(pts_math*1000).astype('int')
    planar_mean = alphas[-3] + np.pi/2*alphas[-2] + np.pi*alphas[-1] + alphas[:-3]@int_vals[coords[:,0], coords[:,1]] # equation 61

    # plot interpolating surface with planar mean
    fig, ax = plt.subplots()
    im = plt.imshow(predictions_plane.reshape((M,N))[::-1,:], cmap='coolwarm', vmin=lower_, vmax=upper_, extent=ext) 
    plt.xlabel(r"Longitude (degrees)")
    plt.ylabel(r"Latitude (degrees)")
    plt.title(r"Interpolating Surface, TPS (planar, $\lambda=$"+f"{l_tps:.3e})")
    textstr = r'mean$=%.3f$' % (planar_mean, )
    txt = ax.text(0.04, 0.96, textstr, transform=ax.transAxes, fontsize=10, verticalalignment='top', bbox=props)
    txt.set_bbox({'alpha':0.05})
    cax = fig.add_axes([ax.get_position().x1+0.03,ax.get_position().y0,0.02,ax.get_position().height])
    plt.colorbar(im, cax=cax) # Similar to fig.colorbar(im, cax = cax)
    plt.savefig("interp_plane.png", bbox_inches='tight')
    plt.close()

    # tps defaults (l=0, training=True) add a zero smoothing term, i.e. the plain kernel matrix
    recovered_plane = tps(pts_math[:,0], pts_math[:,1], pts_math[:,0], pts_math[:,1]) @ alphas
    recovered_plane = recovered_plane[:-3].flatten()
    
    # see how well this interpolating surface recovers the data it was trained on
    plt.figure()
    plt.scatter(pts_display[:,1], pts_display[:,0], c=np.abs(T-recovered_plane), cmap='coolwarm', s=10)
    plt.xlabel(r"Longitude (degrees)")
    plt.ylabel(r"Latitude (degrees)")
    plt.title(r"Reconstruction Error, TPS (planar, $\lambda=$"+f"{l_tps:.3e})")
    plt.colorbar()
    plt.savefig("reconstruction_err_plane.png", bbox_inches="tight")
    plt.close()

    # compare spherical and planar interpolating surfaces
    # only the observations that fall inside the plotted window are marked
    mask = np.logical_and( np.logical_and(pts_display[:,0] >= min_lat, pts_display[:,0] <= max_lat), np.logical_and(pts_display[:,1] >= min_lon, pts_display[:,1] <= max_lon) )

    err = np.abs(predictions_sphere-predictions_plane)
    print("err", np.mean(err), np.min(err), np.max(err), np.median(err))
    plt.figure()
    # fixed colour for the markers; a colormap would be ignored here since no 'c' data is given
    plt.scatter(pts_display[mask,1], pts_display[mask,0], color='red', marker='x', label='observation locations')
    plt.scatter(latlong_te_disp[:,1], latlong_te_disp[:,0], c=err, alpha=0.4, s=0.1, vmin=np.min(err), vmax=np.max(err), cmap='coolwarm')
    plt.colorbar()
    plt.legend(fancybox=True, framealpha=0.5)
    plt.ylabel(r"Latitude (degrees)")
    plt.xlabel(r"Longitude (degrees)")
    plt.title("Absolute Difference: Spherical vs. Planar TPS")
    plt.savefig("diff_sphere_plane.png", bbox_inches='tight')
    plt.close()
