import argparse
import glob
from torch.autograd import Variable
from models import DnCNN
from utils import tensor2uint, mkdir, calculate_ssim, batch_PSNR, imsave
import os
import torch
import numpy as np
import cv2

# Command-line options. They are parsed at import time and read by main()
# through the module-level ``opt``.
parser = argparse.ArgumentParser(description="DnCNN_Test")
parser.add_argument("--num_of_layers", type=int, default=17,
                    help="Number of total layers")
parser.add_argument("--logdir", type=str, default="S25",
                    help="sub-directory of logs/ containing net.pth; "
                         "also used as a prefix of the output folder names")
parser.add_argument("--test_data", type=str, default='CBSD68',
                    help="name of the test-set folder under data/ "
                         "(every *.png in it is evaluated)")
parser.add_argument("--test_noiseL", type=float, default=25,
                    help="std of the additive Gaussian noise, "
                         "on the 0-255 pixel scale")
parser.add_argument("--results", type=str, default='results',
                    help="root directory for the denoised output images")
parser.add_argument("--add_noise", type=str, default='add_noise',
                    help="root directory for the saved noisy input images")
opt = parser.parse_args()

def normalize(data):
    """Map pixel intensities from the 0-255 range onto 0-1.

    Accepts any value that supports true division by a float (numbers,
    NumPy arrays, tensors) and returns the scaled result.
    """
    max_intensity = 255.
    scaled = data / max_intensity
    return scaled

_MODULE_PREFIX = 'module.'


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with one leading ``module.`` removed from each key.

    Presumably the checkpoint was saved from an ``nn.DataParallel`` wrapper,
    which prefixes parameter names with ``module.``; the bare model built in
    main() expects the unprefixed names. Only the leading occurrence is
    stripped, so any inner ``module.`` segment in a key is preserved.
    """
    return {
        (key[len(_MODULE_PREFIX):] if key.startswith(_MODULE_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def _bgr_to_network_input(bgr):
    """Convert an H x W x 3 image array into a 1 x 4 x H x W float32 tensor scaled by 1/255.

    A zero-filled plane is prepended as channel 0 so the input matches the
    4-channel model built in main(); the remaining three channels keep the
    array's original channel order (BGR for ``cv2.imread`` output).
    """
    # Convert to float before concatenating so torch.cat never has to mix
    # uint8 and float32 inputs.
    img = torch.from_numpy(bgr).float()
    zeros_channel = torch.zeros((img.shape[0], img.shape[1], 1))
    img = torch.cat([zeros_channel, img], dim=2)   # H x W x 4
    img = img.permute(2, 0, 1).unsqueeze(0)         # 1 x 4 x H x W
    return normalize(img).contiguous()


def main():
    """Evaluate a trained DnCNN checkpoint on a folder of PNG test images.

    Loads ``logs/<logdir>/net.pth`` into a 4-channel DnCNN on the CPU. Then,
    for every ``data/<test_data>/*.png`` (in sorted order), it adds Gaussian
    noise with std ``test_noiseL / 255`` to the image scaled by 1/255 and
    denoises it. It prints the PSNR of the noisy and of the denoised image
    and the SSIM of the denoised image, and saves both the noisy and the
    denoised images. At the end it prints the mean PSNR and SSIM over the set.

    Raises:
        IOError: if a test image cannot be decoded by OpenCV.
    """
    # Build model
    print('Loading model ...\n')
    channels = 4  # 3 image channels plus one prepended zero plane
    model = DnCNN(channels, num_of_layers=opt.num_of_layers)

    ckpt_path = os.path.join('logs', opt.logdir, 'net.pth')
    state_dict = _strip_module_prefix(torch.load(ckpt_path, map_location='cpu'))
    # strict=False tolerates key mismatches, but doing so silently can leave
    # layers at their random initialisation, so report any mismatch.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        print('Warning: keys missing from checkpoint: %s' % load_result.missing_keys)
    if load_result.unexpected_keys:
        print('Warning: unexpected keys in checkpoint: %s' % load_result.unexpected_keys)
    model.eval()

    # load data info
    print('Loading data info ...\n')
    # Output directories for the denoised results and the noisy test images.
    result_name = str(opt.logdir) + opt.test_data + '_' + str(opt.test_noiseL)
    E_path = os.path.join(opt.results, result_name)
    mkdir(E_path)

    noise_name = str(opt.logdir) + opt.test_data + '_' + 'noise' + '_' + str(opt.test_noiseL)
    Noise_path = os.path.join(opt.add_noise, noise_name)
    mkdir(Noise_path)

    data_dir = os.path.join('data', opt.test_data)
    files_source = sorted(glob.glob(os.path.join(data_dir, '*.png')))
    if not files_source:
        # Nothing to average over; avoid a ZeroDivisionError below.
        print('No PNG images found in %s' % data_dir)
        return

    noise_std = opt.test_noiseL / 255.
    psnr_test = 0
    ssim_test = 0
    for f in files_source:
        file_name = os.path.basename(f)
        bgr = cv2.imread(f)
        if bgr is None:
            # cv2.imread reports unreadable/undecodable files by returning None.
            raise IOError('Could not read image %s' % f)
        source = _bgr_to_network_input(bgr)

        # NOTE(review): noise is also added to the zero plane, and PSNR/SSIM
        # below are computed over all 4 channels; confirm this is intended.
        noise = torch.empty_like(source).normal_(mean=0, std=noise_std)
        noisy = source + noise

        psnr_noise = batch_PSNR(noisy, source, 1.)
        print("psnr_noise", psnr_noise)

        with torch.no_grad():
            out = torch.clamp(model(noisy), 0., 1.)

        psnr = batch_PSNR(out, source, 1.)
        psnr_test += psnr
        print("%s PSNR %f" % (f, psnr))

        # NOTE(review): tensor2uint/calculate_ssim/imsave receive 4-channel
        # data (zero plane first); confirm the utils helpers handle that layout.
        out = tensor2uint(out)
        source = tensor2uint(source)
        noisy = tensor2uint(noisy)

        ssim = calculate_ssim(out, source, border=0)
        ssim_test += ssim
        print("%s SSIM %f" % (f, ssim))

        imsave(noisy, os.path.join(Noise_path, file_name))
        imsave(out, os.path.join(E_path, file_name))

    psnr_test /= len(files_source)
    print("\nPSNR on test data %f" % psnr_test)

    ssim_test /= len(files_source)
    print("\nSSIM on test data %f" % ssim_test)



if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script.
    main()
