import glob
import torch.optim as optim
import time
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.nn.modules.loss import _Loss
from models import DnCNN
from utils import batch_PSNR, weights_init_kaiming
import argparse
from tqdm import tqdm
from data_generator_r import DenoisingDataset
import os
import numpy as np
import cv2

def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: it calls ``bool("False")``,
    which is ``True`` because every non-empty string is truthy.

    Args:
        value: The raw argument text (or an actual ``bool``).

    Returns:
        ``True`` or ``False`` according to the spelled-out value. The empty
        string maps to ``False``, as it did under ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean
            spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % value)


# Command-line options for the training script. They are parsed once at import
# time and read by main() through the module-level ``opt`` namespace.
parser = argparse.ArgumentParser(description="DnCNN")
parser.add_argument("--preprocess", type=_str2bool, default=False, help='run prepare_data or not')
parser.add_argument("--batchSize", type=int, default=128, help="Training batch size")
parser.add_argument("--num_of_layers", type=int, default=17, help="Number of total layers")
parser.add_argument("--epochs", type=int, default=70, help="Number of training epochs")
parser.add_argument("--milestone", type=int, default=30, help="When to decay learning rate; should be less than epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
parser.add_argument("--outf", type=str, default="logs", help='path of log files')
parser.add_argument("--mode", type=str, default="S", help='with known noise level (S) or blind training (B)')
parser.add_argument("--noiseL", type=float, default=25, help='noise level; ignored when mode=B')
parser.add_argument("--val_noiseL", type=float, default=25, help='noise level used on validation set')
parser.add_argument("--test_data",type=str,default='cc',help='test on cc')
parser.add_argument('--train_data', default='data/RealTrain', type=str, help='path of train data')

opt = parser.parse_args()


def normalize(data):
    """Return ``data`` divided by 255.0.

    Works on anything that supports true division by a float (scalars,
    NumPy arrays, ...); the input itself is not modified.
    """
    max_intensity = 255.
    return data / max_intensity

def _learning_rate_for_epoch(epoch, base_lr):
    """Return the step-decayed learning rate for a zero-based epoch index.

    Epochs 0-14 use ``base_lr``, epochs 15-49 use ``base_lr / 5`` and every
    later epoch uses ``base_lr / 25``.
    """
    if epoch < 15:
        return base_lr
    if epoch < 50:
        return base_lr / 5.
    return base_lr / 25.


def _prepend_zero_channel(batch):
    """Return a 4-D ``batch`` with one all-zero channel inserted at index 0.

    The zero plane is created on the CPU, so ``batch`` must still be a CPU
    tensor when this is called.
    """
    zeros_channel = torch.zeros((batch.shape[0], 1, batch.shape[2], batch.shape[3]))
    return torch.cat([zeros_channel, batch], dim=1)


def _load_validation_image(path):
    """Read an image file and return it as a CUDA float tensor with a zero channel.

    The image is read with ``cv2.imread``, rearranged from H x W x C to
    1 x C x H x W, given a leading all-zero channel, scaled by ``normalize``
    and cast to float32 before being moved to the GPU.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``path`` (``cv2.imread``
            returns ``None`` instead of raising).
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError("could not read image: %s" % path)
    batch = np.transpose(image, (2, 0, 1))[np.newaxis]
    # np.zeros is float64, so the division below happens in float64 before
    # the final cast to float32.
    zeros_channel = np.zeros((batch.shape[0], 1, batch.shape[2], batch.shape[3]))
    batch = np.concatenate([zeros_channel, batch], axis=1)
    return torch.from_numpy(np.float32(normalize(batch))).cuda()


def main():
    """Train DnCNN on noisy/clean pairs and validate after every epoch.

    Reads its configuration from the module-level ``opt`` namespace. For each
    epoch it sets the learning rate, trains on every batch of the data loader,
    then computes the mean PSNR over the ``*_real.png`` / ``*_mean.png`` image
    pairs found in ``opt.test_data``. Whenever that mean improves, the model
    weights are saved to ``<opt.outf>/real1.pth``. After the last epoch, every
    per-image validation PSNR is written, one per line, to ``psnr.txt`` inside
    a timestamped run directory under ``opt.outf``.

    Requires a CUDA device.

    Raises:
        FileNotFoundError: If no validation images match, or one of the
            validation images cannot be read.
    """
    # Timestamped run directory inside opt.outf. makedirs also creates
    # opt.outf itself, which the checkpoint save below relies on.
    run_name = 'sigma' + str(opt.noiseL) + '_' + time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
    save_dir = os.path.join(opt.outf, run_name)
    os.makedirs(save_dir, exist_ok=True)

    # The validation file list does not change between epochs, so resolve it
    # once and fail before training rather than after the first epoch.
    files_source = sorted(glob.glob(os.path.join(opt.test_data, '*_real.png')))
    if not files_source:
        raise FileNotFoundError(
            "no '*_real.png' validation images found in %r" % opt.test_data)

    # Load dataset
    print('Loading dataset ...\n')
    train_set = DenoisingDataset()
    train_loader = DataLoader(dataset=train_set, num_workers=0, drop_last=True,
                              batch_size=opt.batchSize, shuffle=True)

    # Model, loss function and optimizer
    net = DnCNN(channels=4, num_of_layers=opt.num_of_layers)
    # reduction='sum' is the non-deprecated spelling of size_average=False.
    criterion = nn.MSELoss(reduction='sum')
    device_ids = [0]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    criterion.cuda()
    optimizer = optim.RMSprop(model.parameters(), lr=opt.lr)

    best_psnr_val = 0
    psnr_list = []
    num_batches = len(train_loader)

    for epoch in range(opt.epochs):
        # Learning rate scheduling
        current_lr = _learning_rate_for_epoch(epoch, opt.lr)
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print('learning rate %f' % current_lr)

        loop = tqdm(enumerate(train_loader, 0), total=num_batches)
        for i, (noisy, clean) in loop:
            # The PSNR monitoring at the end of each step switches to eval
            # mode, so training mode has to be restored every iteration.
            model.train()

            # Pad both batches with a leading zero channel to match the
            # 4-channel network (assumes 3-channel batches - TODO confirm
            # against DenoisingDataset), then move them to the GPU.
            noisy = _prepend_zero_channel(noisy).cuda()
            clean = _prepend_zero_channel(clean).cuda()

            prediction = model(noisy)
            # Summed squared error scaled by 1 / (2 * batch size).
            loss = criterion(prediction, clean) / (noisy.size()[0] * 2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Progress-bar PSNR on the batch just trained on; no gradients
            # are needed, so skip building the autograd graph.
            model.eval()
            with torch.no_grad():
                clamped_prediction = torch.clamp(model(noisy), 0., 1.)
            psnr_train = batch_PSNR(clamped_prediction, clean, 1.)
            loop.set_description("[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f" % (epoch + 1, i + 1, num_batches, loss.item(), psnr_train))

        # Validation phase
        model.eval()
        psnr_test = 0
        for noisy_path in files_source:
            noisy_image = _load_validation_image(noisy_path)
            # The reference image shares the file stem: X_real.png -> X_mean.png
            clean_path = noisy_path[:-len('_real.png')] + '_mean.png'
            clean_image = _load_validation_image(clean_path)

            with torch.no_grad():  # this can save much memory
                denoised = torch.clamp(model(noisy_image), 0., 1.)
            psnr = batch_PSNR(denoised, clean_image, 1.)
            psnr_list.append(psnr)
            psnr_test += psnr

        psnr_test /= len(files_source)
        print("\n[epoch %d] PSNR_test: %.4f" % (epoch + 1, psnr_test))
        if best_psnr_val < psnr_test:
            best_psnr_val = psnr_test
            torch.save(model.state_dict(), os.path.join(opt.outf, 'real1.pth'))
            print(" Saving best model in ", opt.outf)

    # PSNR values are numbers, so format them explicitly instead of
    # concatenating them with a string.
    with open(os.path.join(save_dir, 'psnr.txt'), 'w') as log_file:
        log_file.writelines('%s\n' % value for value in psnr_list)

# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
