import cv2
import os
import argparse
import glob
import torch
from torch.autograd import Variable
from models import DnCNN
from utils import tensor2uint, calculate_ssim, batch_PSNR, imsave
import numpy as np

# Command-line options for the test script.
# NOTE: arguments are parsed at import time (module level), so importing this
# module from another program also parses that program's sys.argv.
parser = argparse.ArgumentParser(description="AQNet_test")
# NOTE(review): main() builds DnCNN(channels=4) without this value; confirm
# whether it should be forwarded to the model constructor.
parser.add_argument(
    "--num_of_layers", type=int, default=17, help="Number of total layers"
)
parser.add_argument(
    "--logdir",
    type=str,
    default=os.path.join(os.path.dirname(__file__), "logs"),
    help="path of log files",
)
# The default is a machine-specific absolute path; override it with --test_data.
parser.add_argument(
    "--test_data",
    type=str,
    default="D:/phD/finish_code/QMSANet/real_noisy/data/cc/cc",
    help="directory of real noisy images (*_real.png) with matching clean "
    "references (*_mean.png)",
)
# NOTE(review): the options below are not read by main() in this file.
parser.add_argument(
    "--test_noiseL", type=float, default=25, help="noise level used on test set"
)
parser.add_argument("--ori_noise", type=str, default="ori_noise", help="ori_noise")
parser.add_argument("--denoise", type=str, default="denoise", help="denoise")
parser.add_argument("--results", type=str, default="results", help="results")

opt = parser.parse_args()


def normalize(data):
    """Map 8-bit pixel intensities from the range [0, 255] onto [0, 1].

    Accepts anything that supports true division by a Python float
    (numbers, NumPy arrays, torch tensors) and returns the quotient.
    """
    max_intensity = 255.0
    scaled = data / max_intensity
    return scaled


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from each key.

    Checkpoints saved from an ``nn.DataParallel`` wrapper store every parameter
    as ``module.<name>``. Only the *leading* prefix is stripped, so a key such as
    ``module.submodule.weight`` becomes ``submodule.weight`` instead of having
    every ``module.`` occurrence removed (which ``str.replace`` would do).
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_4ch_image(path):
    """Load an image file as a ``(1, 4, H, W)`` float32 tensor scaled to [0, 1].

    ``cv2.imread`` (default colour mode) returns an ``(H, W, 3)`` uint8 BGR
    array; an all-zero channel is prepended to obtain the 4 input channels the
    model in ``main`` is built with, then the result is rearranged to NCHW and
    divided by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread``
            signals failure by returning ``None`` rather than raising).
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError("Could not read image: %s" % path)
    # Convert to float first so both torch.cat inputs share the same dtype.
    hwc = torch.from_numpy(img).float()
    zeros_channel = torch.zeros((hwc.shape[0], hwc.shape[1], 1))
    hwc = torch.cat([zeros_channel, hwc], dim=2)
    nchw = hwc.unsqueeze(0).permute(0, 3, 1, 2).contiguous()
    return normalize(nchw)


def main():
    """Evaluate the pretrained model on real noisy images and save denoised outputs.

    Every ``*_real.png`` in ``opt.test_data`` is paired with the ``*_mean.png``
    of the same stem as its clean reference. For each image the PSNR of the
    noisy input, and the PSNR and SSIM of the model output, are printed; the
    averages are printed at the end. Each output is written to the
    ``<opt.test_data>_realdenoise`` directory.
    """
    print("Loading model ...\n")

    # Outputs are saved beside the test folder as "<test_data>_realdenoise".
    # exist_ok avoids a FileExistsError when the script is re-run.
    save_dir = opt.test_data + "_realdenoise"
    os.makedirs(save_dir, exist_ok=True)

    # NOTE(review): opt.num_of_layers is not passed to DnCNN here; confirm
    # whether the constructor should receive it.
    model = DnCNN(channels=4)
    # map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only machine.
    state_dict = torch.load(
        os.path.join(opt.logdir, "new3_net.pth"), map_location="cpu"
    )
    incompatible = model.load_state_dict(
        _strip_module_prefix(state_dict), strict=False
    )
    # strict=False silently skips mismatched keys; surface them so that loading
    # the wrong checkpoint does not go unnoticed.
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print("Warning: missing keys in checkpoint:", incompatible.missing_keys)
        print(
            "Warning: unexpected keys in checkpoint:", incompatible.unexpected_keys
        )
    model.eval()

    print("Loading data info ...\n")
    noisy_suffix = "_real.png"
    files_source = sorted(glob.glob(os.path.join(opt.test_data, "*" + noisy_suffix)))
    if not files_source:
        # Nothing to evaluate; also guards the averages below against 0-division.
        print("No *%s images found in %s" % (noisy_suffix, opt.test_data))
        return

    psnr_test = 0
    ssim_test = 0
    for f in files_source:
        noisy = _load_4ch_image(f)
        # Clean reference shares the stem: "<stem>_real.png" -> "<stem>_mean.png".
        clean = _load_4ch_image(f[: -len(noisy_suffix)] + "_mean.png")

        psnr_noise = batch_PSNR(noisy, clean, 1.0)
        print("psnr_noise", psnr_noise)

        with torch.no_grad():  # inference only: no autograd graph needed
            out = torch.clamp(model(noisy), 0.0, 1.0)

        psnr = batch_PSNR(out, clean, 1.0)
        psnr_test += psnr
        print("%s PSNR %f" % (f, psnr))

        out_uint = tensor2uint(out)
        clean_uint = tensor2uint(clean)
        ssim = calculate_ssim(out_uint, clean_uint, border=0)
        ssim_test += ssim
        print("%s SSIM %f" % (f, ssim))
        # Saved under the same file name as the noisy input.
        imsave(out_uint, os.path.join(save_dir, os.path.basename(f)))

    psnr_test /= len(files_source)
    print("\nPSNR on test data %f" % psnr_test)
    ssim_test /= len(files_source)
    print("\nSSIM on test data %f" % ssim_test)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import;
    # main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())
