import os
import argparse
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from models import DnCNN
from dataset import prepare_data, Dataset
from utils import batch_PSNR, weights_init_kaiming
from tqdm import tqdm

def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to bool.

    ``type=bool`` is a trap with argparse: ``bool("False")`` is ``True``, so
    ``--preprocess False`` would silently enable preprocessing. An empty string
    maps to False, matching what ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


parser = argparse.ArgumentParser(description="DnCNN")
parser.add_argument("--preprocess", type=_str2bool, default=False, help='run prepare_data or not')
parser.add_argument("--batchSize", type=int, default=64, help="Training batch size")
parser.add_argument("--num_of_layers", type=int, default=17, help="Number of total layers")
# Number of training epochs (default 50); change the default here or pass --epochs.
parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
# NOTE(review): main() in this file does not read --milestone; its LR schedule uses fixed epoch boundaries.
parser.add_argument("--milestone", type=int, default=30, help="When to decay learning rate; should be less than epochs")
parser.add_argument("--lr", type=float, default=1e-3, help="Initial learning rate")
parser.add_argument("--outf", type=str, default="logs", help='path of log files')
# Only 'S' and 'B' are handled by the training loop; reject anything else at parse time.
parser.add_argument("--mode", type=str, default="S", choices=("S", "B"), help='with known noise level (S) or blind training (B)')
parser.add_argument("--noiseL", type=float, default=25, help='noise level; ignored when mode=B')
parser.add_argument("--val_noiseL", type=float, default=25, help='noise level used on validation set')
opt = parser.parse_args()

def _prepend_zero_channel(img):
    """Return ``img`` with an all-zero channel concatenated in front along dim 1.

    The network in ``main`` is built with ``channels=4``. This pads the loaded
    images up to that count (presumably 3-channel NCHW batches; TODO confirm
    against ``Dataset``).
    """
    zeros_channel = torch.zeros((img.shape[0], 1, img.shape[2], img.shape[3]))
    return torch.cat([zeros_channel, img], dim=1)


def _lr_for_epoch(epoch, base_lr):
    """Return the learning rate for ``epoch`` under a fixed step schedule.

    The rate is ``base_lr`` for epochs [0, 20), ``base_lr / 10`` for [20, 50),
    and ``base_lr / 100`` afterwards.

    NOTE(review): the boundaries are hard-coded and ``--milestone`` is not used.
    With the default ``--epochs 50`` the /100 stage is never reached.
    """
    if epoch < 20:
        return base_lr
    if epoch < 50:
        return base_lr / 10.
    return base_lr / 100.


def _sample_noise(shape, mode, noise_level, blind_range):
    """Draw zero-mean additive Gaussian noise of ``shape``, scaled to [0, 1] intensities.

    Args:
        shape: Output size. Indexing assumes 4-D (batch first) in mode 'B'.
        mode: 'S' uses one std ``noise_level / 255`` for every element.
            'B' gives each sample its own std, drawn uniformly from
            ``blind_range`` (in 0-255 units) and divided by 255.
        noise_level: Std in 0-255 units, used only in mode 'S'.
        blind_range: ``(low, high)`` std range, used only in mode 'B'.

    Raises:
        ValueError: if ``mode`` is neither 'S' nor 'B'.
    """
    if mode == 'S':
        return torch.empty(shape).normal_(mean=0, std=noise_level / 255.)
    if mode == 'B':
        stds = np.random.uniform(blind_range[0], blind_range[1], size=shape[0])
        # One std per sample, broadcast over the remaining three dims.
        scale = torch.from_numpy(stds / 255.).float().view(-1, 1, 1, 1)
        return torch.randn(shape) * scale
    raise ValueError("mode must be 'S' or 'B', got %r" % (mode,))


def _validate(model, dataset_val, val_noise_level):
    """Return the mean PSNR of ``model`` over ``dataset_val`` at a fixed noise level.

    Each validation item is batched on its own, padded with a zero channel and
    corrupted with Gaussian noise of std ``val_noise_level / 255``. The clamped
    model output is then scored against the clean image. Runs under
    ``torch.no_grad()`` so no autograd graph is built.

    NOTE: SSIM is not computed (``batch_SSIM`` is not imported in this file).

    Raises:
        ValueError: if ``dataset_val`` is empty.
    """
    num_val = len(dataset_val)
    if num_val == 0:
        raise ValueError("validation dataset is empty")
    model.eval()
    psnr_total = 0.
    with torch.no_grad():
        for k in range(num_val):
            img_val = _prepend_zero_channel(torch.unsqueeze(dataset_val[k], 0))
            # Synthetic noisy image
            imgn_val = img_val + _sample_noise(img_val.size(), 'S', val_noise_level, None)
            img_val, imgn_val = img_val.cuda(), imgn_val.cuda()
            out_val = torch.clamp(model(imgn_val), 0., 1.)
            psnr_total += batch_PSNR(out_val, img_val, 1.)
    return psnr_total / num_val


def main():
    """Train DnCNN and save the checkpoint with the best validation PSNR.

    All settings come from the module-level ``opt``. The model is trained to
    output the clean image directly: the loss compares its output with the
    clean batch. A checkpoint is written to
    ``<outf>/<mode>_<int(noiseL)>_net.pth`` whenever validation PSNR improves.

    Raises:
        ValueError: if ``opt.mode`` is not 'S' or 'B', or the validation set is empty.
    """
    # Fail before the expensive dataset load rather than on the first batch.
    if opt.mode not in ('S', 'B'):
        raise ValueError("--mode must be 'S' or 'B', got %r" % (opt.mode,))

    # Load dataset
    print('Loading dataset ...\n')
    dataset_train = Dataset(train=True)
    dataset_val = Dataset(train=False)
    loader_train = DataLoader(dataset=dataset_train, num_workers=4, batch_size=opt.batchSize, shuffle=True)
    print("# of training samples: %d\n" % int(len(dataset_train)))

    # 4 input channels to match _prepend_zero_channel.
    net = DnCNN(channels=4, num_of_layers=opt.num_of_layers)
    net.apply(weights_init_kaiming)
    # Summed squared error. Divided by 2 * batch size below, i.e. the batch mean of 0.5 * ||out - clean||^2.
    criterion = nn.MSELoss(reduction='sum')
    device_ids = [0]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    criterion.cuda()
    optimizer = torch.optim.RMSprop(model.parameters(), lr=opt.lr)
    noiseL_B = [0, 55]  # blind-mode std range, in 0-255 units
    best_psnr_val = 0

    for epoch in range(opt.epochs):
        current_lr = _lr_for_epoch(epoch, opt.lr)
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print('learning rate %f' % current_lr)

        loop = tqdm(enumerate(loader_train, 0), total=len(loader_train))
        for i, data in loop:
            # Switch back to train mode each step: the PSNR probe below sets eval mode.
            model.train()
            optimizer.zero_grad()
            img_train = _prepend_zero_channel(data)
            imgn_train = img_train + _sample_noise(img_train.size(), opt.mode, opt.noiseL, noiseL_B)
            img_train, imgn_train = img_train.cuda(), imgn_train.cuda()

            out_train = model(imgn_train)
            loss = criterion(out_train, img_train) / (imgn_train.size()[0] * 2)
            loss.backward()
            optimizer.step()

            # Progress metric: re-run the updated model in eval mode without building a graph.
            model.eval()
            with torch.no_grad():
                out_train = torch.clamp(model(imgn_train), 0., 1.)
                psnr_train = batch_PSNR(out_train, img_train, 1.)

            loop.set_description("[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f" %
                                 (epoch + 1, i + 1, len(loader_train), loss.item(), psnr_train))

        # Validation phase at end of epoch
        psnr_val = _validate(model, dataset_val, opt.val_noiseL)
        print("\n[epoch %d] PSNR_val: %.4f" % (epoch + 1, psnr_val))
        if best_psnr_val < psnr_val:
            best_psnr_val = psnr_val
            # torch.save does not create missing directories.
            os.makedirs(opt.outf, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(opt.outf, opt.mode + '_' + str(int(opt.noiseL)) + '_net.pth'))
            print(" Saving best model in ", opt.outf)



if __name__ == "__main__":
    # Both supported modes extract patches with identical settings. One
    # guarded call replaces the two per-mode branches, and patching still
    # happens only for mode 'S' or 'B'.
    if opt.preprocess and opt.mode in ('S', 'B'):
        prepare_data(data_path='data', patch_size=50, stride=40, aug_times=1)
    main()
