import cv2
import os
import argparse
import torch
from models import DnCNN
from utils import tensor2uint, calculate_ssim, batch_PSNR, imsave
import numpy as np

# Command-line options. They are parsed at import time into the module-level
# ``opt`` namespace that main() reads.
# NOTE: --num_of_layers, --test_noiseL, --ori_noise, --denoise and --results
# are accepted but not read by main() in this file.
parser = argparse.ArgumentParser(description="AQNet_test")
parser.add_argument(
    "--num_of_layers", type=int, default=17, help="Number of total layers"
)
parser.add_argument(
    "--logdir",
    type=str,
    default=os.path.join(os.path.dirname(__file__), "logs"),
    help="directory containing the model checkpoint (new3_net.pth)",
)
# No machine-specific default: when omitted, main() raises a clear error
# asking for --test_image instead of failing on a non-existent hard-coded path.
parser.add_argument(
    "--test_image",
    type=str,
    default=None,
    help="path to the noisy '*_real.png' test image; its clean '*_mean.png' "
    "counterpart must be in the same directory",
)
parser.add_argument(
    "--test_noiseL", type=float, default=25, help="noise level used on test set"
)
parser.add_argument("--ori_noise", type=str, default="ori_noise", help="ori_noise")
parser.add_argument("--denoise", type=str, default="denoise", help="denoise")
parser.add_argument("--results", type=str, default="results", help="results")

opt = parser.parse_args()


def normalize(data):
    """Scale 8-bit pixel intensities into the unit interval.

    The input is true-divided by 255.0, so 0 maps to 0.0 and 255 maps to 1.0.
    Any operand that supports ``/`` with a float works (Python numbers,
    NumPy arrays, torch tensors).
    """
    max_pixel_value = 255.0
    return data / max_pixel_value


def _load_4ch_tensor(path):
    """Read an image with OpenCV and return it as a float32 NCHW tensor.

    A zero channel is prepended to the image's channels, giving shape
    (1, 1 + C, H, W). That is 4 channels for a 3-channel read, which matches
    the ``DnCNN(channels=4)`` model. Values are divided by 255 via normalize().

    Raises:
        FileNotFoundError: if OpenCV cannot read *path*.
    """
    img = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    img = torch.tensor(img)
    zeros_channel = torch.zeros((img.shape[0], img.shape[1], 1))
    # Channel-last concat -> (H, W, 1 + C); then add a batch dim and go NCHW.
    img = torch.cat([zeros_channel, img], dim=2)
    img = img.unsqueeze(0).permute(0, 3, 1, 2).numpy()
    return torch.Tensor(np.float32(normalize(img)))


def _load_model(checkpoint_path):
    """Build ``DnCNN(channels=4)``, load *checkpoint_path* on CPU, set eval mode."""
    model = DnCNN(channels=4)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    # Presumably saved from nn.DataParallel, which prefixes keys with "module.".
    # Strip only that leading prefix; str.replace would also drop inner matches.
    prefix = "module."
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    result = model.load_state_dict(state_dict, strict=False)
    # strict=False tolerates mismatched keys silently; report them so a
    # partially loaded model does not go unnoticed.
    if result.missing_keys or result.unexpected_keys:
        print("Warning: checkpoint keys do not match the model")
        print("  missing keys:", result.missing_keys)
        print("  unexpected keys:", result.unexpected_keys)
    model.eval()
    return model


def main():
    """Denoise one real-noisy image and report PSNR/SSIM against its clean pair.

    Settings come from the module-level ``opt`` namespace. The clean reference
    is located by swapping the ``_real.png`` suffix of ``opt.test_image`` for
    ``_mean.png``. The clamped output is written to ``output.png`` in the
    current working directory.

    Raises:
        ValueError: if ``--test_image`` is empty or does not end in ``_real.png``.
        FileNotFoundError: if the noisy or clean image cannot be found or read.
    """
    print("Loading model ...\n")

    if not opt.test_image:
        raise ValueError(
            "Specify the path to an image to be tested with the --test_image parameter."
        )

    # Without the expected suffix the clean path would equal the noisy path,
    # and the metrics would silently compare the output against the noisy input.
    noisy_suffix, clean_suffix = "_real.png", "_mean.png"
    if not opt.test_image.endswith(noisy_suffix):
        raise ValueError(
            f"Expected --test_image to end with '{noisy_suffix}' so its clean "
            f"'{clean_suffix}' counterpart can be located: {opt.test_image}"
        )

    # Per-image results directory "<stem>_realdenoise", created in the CWD.
    result_name = os.path.splitext(os.path.basename(opt.test_image))[0] + "_realdenoise"
    E_path = result_name
    os.makedirs(E_path, exist_ok=True)

    model = _load_model(os.path.join(opt.logdir, "new3_net.pth"))

    noisy_input = _load_4ch_tensor(opt.test_image)

    clean_img_path = opt.test_image[: -len(noisy_suffix)] + clean_suffix
    if not os.path.exists(clean_img_path):
        raise FileNotFoundError(
            f"Can't find a clean image to correspond to: {clean_img_path}"
        )
    clean_ref = _load_4ch_tensor(clean_img_path)

    with torch.no_grad():
        denoised = torch.clamp(model(noisy_input), 0.0, 1.0)

    # NOTE(review): both tensors include the synthetic leading zero channel;
    # confirm whether batch_PSNR / tensor2uint are meant to see it.
    denoised_uint = tensor2uint(denoised)
    psnr = batch_PSNR(denoised, clean_ref, 1.0)
    ssim = calculate_ssim(denoised_uint, tensor2uint(clean_ref), border=0)

    print("PSNR:", psnr)
    print("SSIM:", ssim)

    # NOTE: E_path is created above, but the result is still written to
    # "output.png" in the CWD (kept for compatibility with existing usage).
    imsave(denoised_uint, "output.png")


# Run the evaluation only when this file is executed as a script, not when it
# is imported (command-line parsing into ``opt`` still happens at import time).
if __name__ == "__main__":
    main()
