import torch
import torch.nn as nn
import numpy as np
from qcnn_convolutional_layer import QuaternionConvolution
from SimAM import SimAM

# Quaternion denoising network definition.
class DnCNN(nn.Module):
    """Quaternion-convolution denoising network (DnCNN-style residual learner).

    Architecture, as built in ``__init__`` and applied in ``forward``:

    * ``conv1_1`` .. ``conv1_15``: QuaternionConvolution -> BatchNorm2d ->
      ReLU blocks with 64 feature channels and 3x3 kernels. Layers 2, 5, 9
      and 12 use ``dilation=2`` with ``padding=2``; the others use
      ``padding=1`` and QuaternionConvolution's default dilation.
      ``forward`` labels layers 1-12 "QMSB" and layers 13-16 "QSEB".
    * ``conv1_16``: QuaternionConvolution down to 4 channels (no BN/ReLU).
    * "LAB": the input is concatenated with ``conv1_16``'s output, passed
      through tanh, a 1x1 QuaternionConvolution (8 -> 4 channels) and SimAM,
      and the result is multiplied by ``conv1_16``'s output.
    * That product is subtracted from the input (residual learning).

    Args:
        channels: number of input channels. NOTE(review): ``conv3`` is
            hard-coded to 8 input channels (= ``channels`` + 4) and
            ``forward`` subtracts a 4-channel tensor from the input, so the
            model presumably only works for ``channels == 4`` -- confirm
            before using other values.
        num_of_layers: accepted for API compatibility but unused; the depth
            is fixed at 16 quaternion convolutions.
    """

    def __init__(self, channels, num_of_layers=15):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        features = 64
        make = self._qconv_bn_relu

        # Keep this exact registration order: it fixes the state_dict keys,
        # the self.modules() order used by _init_weights, and the order in
        # which parameters are created (and thus RNG consumption).
        self.conv1_1 = make(channels, features, kernel_size, padding)
        self.conv1_2 = make(features, features, kernel_size, 2, dilation=2)
        self.conv1_3 = make(features, features, kernel_size, padding)
        self.conv1_4 = make(features, features, kernel_size, padding)
        self.conv1_5 = make(features, features, kernel_size, 2, dilation=2)
        self.conv1_6 = make(features, features, kernel_size, padding)
        self.conv1_7 = make(features, features, kernel_size, padding)
        self.conv1_8 = make(features, features, kernel_size, padding)
        self.conv1_9 = make(features, features, kernel_size, 2, dilation=2)
        self.conv1_10 = make(features, features, kernel_size, padding)
        self.conv1_11 = make(features, features, kernel_size, padding)
        self.conv1_12 = make(features, features, kernel_size, 2, dilation=2)
        self.conv1_13 = make(features, features, kernel_size, padding)
        self.conv1_14 = make(features, features, kernel_size, padding)
        self.conv1_15 = make(features, features, kernel_size, padding)
        self.conv1_16 = QuaternionConvolution(
            in_channels=features, out_channels=4,
            kernel_size=kernel_size, stride=1, padding=1,
        )
        # 8 input channels = input (channels) + conv1_16 output (4).
        self.conv3 = QuaternionConvolution(
            in_channels=8, out_channels=4, kernel_size=1, stride=1, padding=0,
        )
        self.ReLU = nn.ReLU(inplace=True)  # not used in forward; kept for compatibility
        self.Tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()  # not used in forward; kept for compatibility
        self.simam = SimAM(channels)

        self._init_weights()

    @staticmethod
    def _qconv_bn_relu(in_channels, out_channels, kernel_size, padding, **conv_kwargs):
        """Return ``Sequential(QuaternionConvolution, BatchNorm2d, ReLU)``.

        Extra keyword arguments (e.g. ``dilation``) are forwarded to
        QuaternionConvolution only when given. Layers that originally did
        not pass them therefore still get the convolution's defaults.
        """
        return nn.Sequential(
            QuaternionConvolution(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
                **conv_kwargs,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _clip_small_magnitudes_(weight, clip_b):
        """Push entries of ``weight`` with magnitude below ``clip_b`` out to ``+/-clip_b``.

        Works in place and returns ``weight``. Entries in ``[0, clip_b)``
        become ``clip_b``; entries in ``(-clip_b, 0)`` become ``-clip_b``.
        All other entries are left unchanged.
        """
        # Both masks are computed before mutating, so they are disjoint,
        # exactly like the if/elif on each element they replace.
        pos_small = (weight >= 0) & (weight < clip_b)
        neg_small = (weight > -clip_b) & (weight < 0)
        weight.masked_fill_(pos_small, clip_b)
        weight.masked_fill_(neg_small, -clip_b)
        return weight

    def _init_weights(self):
        """Initialise nn.Conv2d and nn.BatchNorm2d submodules in place.

        Conv2d and BatchNorm2d weights are drawn from N(0, sqrt(2 / (9*64))).
        BatchNorm weights are then kept at least 0.025 away from zero, and
        BatchNorm running variances are set to 0.01.
        """
        std = (2 / (9.0 * 64)) ** 0.5
        clip_b = 0.025
        with torch.no_grad():
            for m in self.modules():
                # NOTE(review): only matches nn.Conv2d instances; whether
                # QuaternionConvolution is (or contains) one is not visible
                # from this file -- confirm.
                if isinstance(m, nn.Conv2d):
                    m.weight.normal_(0, std)
                if isinstance(m, nn.BatchNorm2d):
                    m.weight.normal_(0, std)
                    DnCNN._clip_small_magnitudes_(m.weight, clip_b)
                    m.running_var.fill_(0.01)

    def forward(self, x):
        """Return ``x`` minus the network's attention-weighted residual.

        Args:
            x: 4-D input (required by BatchNorm2d). Presumably
                (N, 4, H, W) -- see the class docstring; TODO confirm.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # QMSB module (layers 1-12; layers 2, 5, 9, 12 are dilated)
        x1 = self.conv1_1(x)
        x1 = self.conv1_2(x1)
        x1 = self.conv1_3(x1)
        x1 = self.conv1_4(x1)
        x1 = self.conv1_5(x1)
        x1 = self.conv1_6(x1)
        x1 = self.conv1_7(x1)
        x1 = self.conv1_8(x1)
        x1 = self.conv1_9(x1)
        x1 = self.conv1_10(x1)
        x1 = self.conv1_11(x1)
        x1 = self.conv1_12(x1)
        # QSEB module (layers 13-16; conv1_16 projects to 4 channels)
        x1 = self.conv1_13(x1)
        x1 = self.conv1_14(x1)
        x1 = self.conv1_15(x1)
        x1 = self.conv1_16(x1)
        # LAB module: weights derived from [input, features] gate the features.
        out = torch.cat([x, x1], 1)
        out = self.Tanh(out)
        out = self.conv3(out)
        out = self.simam(out)
        out = out * x1
        # Residual learning: the gated features are subtracted from the input.
        return x - out