import argparse
import glob
from torch.autograd import Variable
from models import DnCNN
from utils import tensor2uint, mkdir, calculate_ssim, batch_PSNR, imsave
import os
import torch
import numpy as np
import cv2

# Command-line options, parsed once at import time into the module-level `opt`
# that main() reads.
parser = argparse.ArgumentParser(description="DnCNN_Test")
parser.add_argument(
    "--num_of_layers", type=int, default=17, help="Number of total layers"
)
parser.add_argument(
    "--logdir",
    type=str,
    default="S25",
    help="checkpoint sub-folder of <script dir>/logs that contains net.pth",
)
# NOTE(review): the default "CBSD68" looks like a dataset name rather than an
# image file; cv2.imread cannot read it unless such a file exists.
parser.add_argument(
    "--single_image",
    type=str,
    default="CBSD68",
    help="path of the single image to denoise and evaluate",
)
parser.add_argument(
    "--test_noiseL",
    type=float,
    default=25,
    help="std of the Gaussian noise (0-255 scale) added when --add_noise is set",
)
parser.add_argument(
    "--add_noise",
    action="store_true",
    help="add Gaussian noise of level --test_noiseL to the image before denoising",
    default=False,
)
opt = parser.parse_args()


def normalize(data):
    """Divide ``data`` by 255.0, mapping 8-bit intensities [0, 255] onto [0, 1].

    Any operand that supports true division by a float works: Python numbers,
    NumPy arrays or torch tensors.
    """
    max_intensity = 255.0
    scaled = data / max_intensity
    return scaled


def _load_model(num_of_layers, logdir, channels=4):
    """Build a DnCNN and load ``<script dir>/logs/<logdir>/net.pth`` into it.

    Args:
        num_of_layers: depth passed to ``DnCNN``.
        logdir: checkpoint sub-folder under the ``logs`` folder next to this script.
        channels: channel count passed to ``DnCNN``.

    Returns:
        The model in eval mode with the checkpoint weights loaded on the CPU.
    """
    net = DnCNN(channels, num_of_layers=num_of_layers)
    ckpt_path = os.path.join(os.path.dirname(__file__), "logs", logdir, "net.pth")
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    # Wrapped-model checkpoints (e.g. from nn.DataParallel) prefix keys with
    # "module."; strip it only at the start of a key, never mid-name.
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    # strict=False is kept for compatibility, but report mismatches so a wrong
    # checkpoint does not silently leave layers at their initial weights.
    incompatible = net.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            "Warning: checkpoint/model mismatch - missing keys: %s, unexpected keys: %s"
            % (incompatible.missing_keys, incompatible.unexpected_keys)
        )
    net.eval()
    return net


def _load_image(image_path):
    """Read an image with OpenCV and return a (1, 4, H, W) float32 tensor in [0, 1].

    Channel 0 is an all-zero plane prepended to the three channels returned by
    ``cv2.imread`` (BGR order), to match the 4-channel model input.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError("Could not read image: %s" % image_path)
    # Convert to float before concatenating; mixing uint8 with the float32
    # zero plane would depend on dtype promotion inside torch.cat.
    pixels = torch.from_numpy(bgr).float()
    zero_plane = torch.zeros((pixels.shape[0], pixels.shape[1], 1), dtype=pixels.dtype)
    hwc = torch.cat([zero_plane, pixels], dim=2)
    nchw = hwc.permute(2, 0, 1).unsqueeze(0).contiguous()
    return normalize(nchw)


def main():
    """Denoise one image with a trained DnCNN and record PSNR/SSIM.

    Reads its settings from the module-level ``opt``. Writes ``results.txt``,
    ``output.png`` and, with ``--add_noise``, ``noisy.png`` into the current
    working directory.

    Raises:
        FileNotFoundError: if ``opt.single_image`` cannot be read by OpenCV.
    """
    print("Loading model ...\n")
    model = _load_model(opt.num_of_layers, opt.logdir, channels=4)

    source = _load_image(opt.single_image)
    if opt.add_noise:
        # Gaussian noise whose std is given on the 0-255 intensity scale.
        noise = torch.empty_like(source).normal_(mean=0.0, std=opt.test_noiseL / 255.0)
        noisy = source + noise
    else:
        noisy = source.clone()

    with torch.no_grad():
        out = torch.clamp(model(noisy), 0.0, 1.0)

    # NOTE(review): both metrics also cover the synthetic all-zero channel 0,
    # which could skew scores versus a 3-channel evaluation - confirm intent.
    psnr = batch_PSNR(out, source, 1.0)
    ssim = calculate_ssim(tensor2uint(out), tensor2uint(source), border=0)
    with open("results.txt", "w", encoding="utf-8") as f:
        f.write("PSNR: %f, SSIM: %f" % (psnr, ssim))

    imsave(tensor2uint(out), "output.png")
    if opt.add_noise:
        imsave(tensor2uint(noisy), "noisy.png")


if __name__ == "__main__":
    main()
