import math
import torch
import torch.nn as nn
import numpy as np
from skimage.metrics import structural_similarity
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
#from skimage.measure.simple_metrics import compare_psnr
import cv2
import os
from datetime import datetime

# Initialize the weights of the network according to the Kaiming initialization method
# Called on each module of the model: quaternion convolution (QuaternionConvolution) and fully connected (Linear) layers get Kaiming-normal weights; batch normalization (BatchNorm) layers get small clamped Gaussian weights and a zero bias
def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find(' QuaternionConvolution') != -1:
        nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        # nn.init.uniform(m.weight.data, 1.0, 0.02)
        m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025,0.025)
        nn.init.constant(m.bias.data, 0.0)

# Calculate the average PSNR value between a batch of images and the corresponding clean (un-noised) images
# Convert the images from a PyTorch tensor to a NumPy array, then compute the PSNR for each pair of images and finally return the average PSNR value
def batch_PSNR(img, imclean, data_range):
    """Return the mean PSNR of a batch of images against clean references.

    Both tensors are detached, moved to the CPU and cast to float32. Each
    batch item ``[k, :, :, :]`` (so the inputs must be 4-D) is compared with
    skimage's PSNR using ``data_range``. The per-item scores are averaged.
    """
    restored = img.data.cpu().numpy().astype(np.float32)
    reference = imclean.data.cpu().numpy().astype(np.float32)
    batch_size = restored.shape[0]
    total = sum(
        compare_psnr(reference[k, :, :, :], restored[k, :, :, :], data_range=data_range)
        for k in range(batch_size)
    )
    return total / batch_size

# Calculate the average SSIM value between a batch of images and the corresponding clean (un-noised) images
def batch_SSIM(img, imclean, data_range=1):
    """Return the mean SSIM between a batch of images and their clean references.

    The inputs are tensors laid out as ``(batch, channels, height, width)``,
    the same layout ``batch_PSNR`` iterates over. SSIM is computed on every
    2-D channel plane of every image and averaged over that image's
    channels. The per-image scores are then averaged over the batch.

    Scoring each plane as a 2-D image reproduces skimage's multichannel
    averaging without its ``multichannel`` / ``channel_axis`` keywords,
    whose availability depends on the installed scikit-image version.

    Args:
        img: restored images, shape (B, C, H, W).
        imclean: clean reference images, same shape as ``img``.
        data_range: pixel value range (max - min); defaults to 1, the value
            this function has always used.

    Returns:
        Mean SSIM over the batch.

    Raises:
        ValueError: if the inputs are not 4-D or their shapes differ.
    """
    Img = img.data.cpu().numpy().astype(np.float32)
    Iclean = imclean.data.cpu().numpy().astype(np.float32)
    if Img.ndim != 4 or Img.shape != Iclean.shape:
        raise ValueError(
            'Expected two (batch, channels, height, width) arrays of the same '
            'shape, got {} and {}.'.format(Img.shape, Iclean.shape))

    SSIM = 0
    for i in range(Img.shape[0]):
        # Score each channel plane of image i as a 2-D image, then average.
        channel_scores = [
            structural_similarity(Iclean[i, c], Img[i, c], data_range=data_range)
            for c in range(Img.shape[1])
        ]
        SSIM += np.mean(channel_scores)
    return SSIM / Img.shape[0]

# Data augmentation by flipping and rotating an image
def data_augmentation(image, mode):
    """Apply one of eight flip/rotation augmentations to a channel-first image.

    The image is moved from (channels, height, width) to (height, width,
    channels), transformed in the height/width plane, and moved back.

    Modes:
        0: unchanged                  1: flip up/down
        2: rotate 90 deg (np.rot90)   3: rotate 90 deg, then flip up/down
        4: rotate 180 deg             5: rotate 180 deg, then flip up/down
        6: rotate 270 deg             7: rotate 270 deg, then flip up/down
    Any other ``mode`` leaves the image unchanged.
    """
    # Index = mode; value = (quarter turns for np.rot90, flip up/down afterwards).
    recipes = (
        (0, False), (0, True),
        (1, False), (1, True),
        (2, False), (2, True),
        (3, False), (3, True),
    )
    hwc = np.transpose(image, (1, 2, 0))
    for code, (turns, flip) in enumerate(recipes):
        if mode == code:
            if turns:
                hwc = np.rot90(hwc, k=turns)
            if flip:
                hwc = np.flipud(hwc)
            break
    return np.transpose(hwc, (2, 0, 1))

# SSIM
# --------------------------------------------
def calculate_ssim(img1, img2, border=0):
    '''Calculate SSIM between two images.

    The original note states this gives the same outputs as MATLAB's
    implementation.

    Args:
        img1, img2: arrays of identical shape with values in [0, 255],
            either 2-D (H, W) or 3-D (H, W, C).
        border: number of pixels cropped from each side of the first two
            axes before comparison (0 keeps the whole image).

    Returns:
        The SSIM value. For 3-D inputs with several channels, it is the
        mean of the per-channel SSIM values.

    Raises:
        ValueError: if the shapes differ or the images are not 2-D or 3-D.
    '''
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim not in (2, 3):
        raise ValueError('Wrong input image dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.shape[2] == 1:
        return ssim(np.squeeze(img1), np.squeeze(img2))
    # Channels are scored independently and averaged, so any channel count
    # (e.g. 3 or 4) yields a value instead of falling through to None.
    ssims = [ssim(img1[:, :, i], img2[:, :, i]) for i in range(img1.shape[2])]
    return np.array(ssims).mean()

def ssim(img1, img2, border=0):
    """Single-channel SSIM using an 11x11 Gaussian window (sigma 1.5).

    The stabilising constants are derived from a 255 dynamic range. Every
    filtered map is trimmed by 5 pixels on each edge before averaging, so
    only the fully-covered ("valid") region contributes. ``border`` is
    accepted but not used.
    """
    c1 = (0.01 * 255) ** 2
    c2 = (0.03 * 255) ** 2
    x = img1.astype(np.float64)
    y = img2.astype(np.float64)
    g = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(g, g.transpose())

    def local_mean(arr):
        # Gaussian-weighted local average, cropped to the valid region.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_x = local_mean(x)
    mu_y = local_mean(y)
    mu_x_sq = mu_x ** 2
    mu_y_sq = mu_y ** 2
    mu_xy = mu_x * mu_y
    var_x = local_mean(x**2) - mu_x_sq
    var_y = local_mean(y**2) - mu_y_sq
    cov_xy = local_mean(x * y) - mu_xy

    numerator = (2 * mu_xy + c1) * (2 * cov_xy + c2)
    denominator = (mu_x_sq + mu_y_sq + c1) * (var_x + var_y + c2)
    return (numerator / denominator).mean()

# Converts a PyTorch tensor to a uint8 numpy array for image saving
def tensor2uint(img):
    """Convert an image tensor with values in [0, 1] to a uint8 NumPy array.

    Singleton dimensions are squeezed away. Values are clamped to [0, 1],
    scaled to [0, 255] and rounded. A 3-D result is treated as
    (C, H, W) and transposed to (H, W, C) for image writers.

    The input tensor is not modified.
    """
    # Out-of-place clamp: .squeeze() returns a view and .float() returns the
    # same tensor when it is already float32, so clamp_() would overwrite
    # the caller's data.
    img = img.data.squeeze().float().clamp(0, 1).cpu().numpy()

    if img.ndim == 3:
        img = img.transpose(1, 2, 0)

    return np.uint8((img * 255.0).round())

def get_timestamp():
    """Return the current local time formatted as ``YYMMDD-HHMMSS``."""
    now = datetime.now()
    return f'{now:%y%m%d-%H%M%S}'

# Creates a directory if it doesn't exist (single directory version).
def mkdir(path):
    """Create directory ``path`` (including missing parents) if nothing exists there.

    If ``path`` already exists, even as a regular file, nothing is done.
    """
    if not os.path.exists(path):
        # exist_ok=True: another process (e.g. a parallel worker) may create
        # the directory between the check above and this call.
        os.makedirs(path, exist_ok=True)


def mkdirs(paths):
    """Create one directory or several, each via ``mkdir``.

    Args:
        paths: a single path (``str``, ``bytes`` or ``os.PathLike``) or an
            iterable of such paths.
    """
    # A single path-like must not be iterated: iterating a pathlib.Path
    # raises TypeError and iterating bytes yields ints.
    if isinstance(paths, (str, bytes, os.PathLike)):
        mkdir(paths)
    else:
        for path in paths:
            mkdir(path)


def mkdir_and_rename(path):
    """Create ``path`` as a fresh directory, archiving anything already there.

    An existing ``path`` is renamed to ``<path>_archived_<timestamp>``
    (timestamp from ``get_timestamp``) and a message is printed. Then
    ``path`` is created with ``os.makedirs``. ``path`` must be a ``str``
    because it is concatenated with the suffix.
    """
    already_there = os.path.exists(path)
    if already_there:
        backup = path + '_archived_' + get_timestamp()
        print('Path already exists. Rename it to [{:s}]'.format(backup))
        os.rename(path, backup)
    os.makedirs(path)

# Saves image with OpenCV
def imsave(img, img_path):
    """Save an image array to ``img_path`` with OpenCV.

    Singleton dimensions are squeezed first. For a 3-D (H, W, C) array only
    channels 1, 2 and 3 are written and channel 0 is dropped, so 3-D input
    needs at least 4 channels.
    NOTE(review): cv2.imwrite interprets the three written channels as
    B, G, R. Confirm that channels 1-3 are stored in that order upstream.

    Returns:
        bool: the success flag reported by ``cv2.imwrite``. It was previously
        discarded, so callers can now detect a failed write.
    """
    img = np.squeeze(img)  # Remove singleton dimensions
    if img.ndim == 3:
        # Keep channels 1..3 (drops channel 0).
        img = img[:, :, [1, 2, 3]]
    return cv2.imwrite(img_path, img)

