import time
import torch

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path
from torch.utils.data import DataLoader
from torchsummary import summary
from torchvision import datasets, transforms

# Custom exception for invalid kernel shapes in quaternion convolution.
class InvalidKernelShape(RuntimeError):
    """Signal that a convolution kernel shape could not be constructed.

    Raised, for example, when the requested dimensionality or kernel size
    is not supported.
    """

    def __init__(self, error_message):
        """Initialise the error with a human-readable message.

        :param error_message: Description of why the kernel shape is invalid.
        """
        RuntimeError.__init__(self, error_message)

# Custom exception for invalid input dimensions in quaternion convolution.
class InvalidInput(RuntimeError):
    """Signal that an input tensor is not acceptable to the layer.

    Raised, for example, when the input's number of dimensions does not map
    to a supported convolution.
    """

    def __init__(self, error_message):
        """Initialise the error with a human-readable message.

        :param error_message: Description of why the input was rejected.
        """
        RuntimeError.__init__(self, error_message)

# Implements a quaternion-valued convolution layer for neural networks.
class QuaternionConvolution(nn.Module):
    """Quaternion-valued convolution layer (Hamilton-product convolution).

    The caller passes the total number of real input/output channels. These are
    split into four quaternion components (r, i, j, k) by floor division by 4, so
    they are presumably expected to be multiples of 4 (any remainder is dropped).
    One weight tensor of shape ``(out_channels // 4, in_channels // 4, *kernel_size)``
    is learned per component. In :meth:`forward` the four are assembled into a
    single real kernel that computes the Hamilton product ``W * x``. The input
    channels are read as four contiguous blocks ordered (r, i, j, k).
    """

    ALLOWED_DIMENSIONS = (2, 3)

    def __init__(self, in_channels, out_channels, kernel_size, stride, dimension=2, padding=0, dilation=1, groups=1,
                 bias=True):
        """Build the layer and initialise its quaternion weights.

        :param in_channels: Total number of real input channels (split into 4 components).
        :param out_channels: Total number of real output channels (split into 4 components).
        :param kernel_size: An int, or a tuple/list with one entry per spatial dimension.
        :param stride: Convolution stride, forwarded to the torch conv function.
        :param dimension: Spatial rank of the kernel; must be in ``ALLOWED_DIMENSIONS``.
        :param padding: Convolution padding, forwarded to the torch conv function.
        :param dilation: Convolution dilation, forwarded to the torch conv function.
        :param groups: Forwarded to the torch conv function. NOTE(review): the weight
            shape is not divided by ``groups``, so only ``groups=1`` appears to work.
        :param bias: If True, learn a zero-initialised bias per produced output channel.
        :raises InvalidKernelShape: If ``dimension`` or ``kernel_size`` is invalid.
        """
        super().__init__()

        # Number of channels per quaternion component (r, i, j, k).
        self.in_channels = in_channels // 4
        self.out_channels = out_channels // 4

        self.groups = groups
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.kernel_size = self.get_kernel_shape(kernel_size, dimension)
        self.weight_shape = self.get_weight_shape(self.in_channels, self.out_channels, self.kernel_size)

        # Four weight tensors: the real part (r) and the three imaginary parts (i, j, k).
        self._weights = self.weight_tensors(self.weight_shape, kernel_size)
        # weight_tensors returns (r, i, j, k); unpack in that same order so each
        # attribute holds the component its name says.
        self.r_weight, self.i_weight, self.j_weight, self.k_weight = self._weights

        if bias:
            # Sized to the number of channels the conv actually produces
            # (4 components * per-component count), which equals out_channels
            # whenever out_channels is a multiple of 4.
            self.bias = nn.Parameter(torch.zeros(4 * self.out_channels))
        else:
            # Register explicitly so forward() can always pass ``self.bias``
            # (None here) to the conv function.
            self.register_parameter('bias', None)

    def forward(self, x):
        """Apply the quaternion convolution to ``x``.

        :param x: Input tensor; ``x.dim()`` of 3/4/5 selects conv1d/conv2d/conv3d.
            Its spatial rank must match the kernel's, and its channel axis must hold
            ``4 * self.in_channels`` channels as (r, i, j, k) blocks.
        :return: Output of the convolution with ``4 * self.out_channels`` channels.
        :raises InvalidInput: If ``x`` does not have 3, 4 or 5 dimensions.
        """
        # Each row is one output component of the Hamilton product W * x; the
        # column blocks multiply the input's (r, i, j, k) channel blocks.
        cat_kernels_4_r = torch.cat([self.r_weight, -self.i_weight, -self.j_weight, -self.k_weight], dim=1)
        cat_kernels_4_i = torch.cat([self.i_weight, self.r_weight, -self.k_weight, self.j_weight], dim=1)
        cat_kernels_4_j = torch.cat([self.j_weight, self.k_weight, self.r_weight, -self.i_weight], dim=1)
        cat_kernels_4_k = torch.cat([self.k_weight, -self.j_weight, self.i_weight, self.r_weight], dim=1)
        cat_kernels_4_quaternion = torch.cat([cat_kernels_4_r, cat_kernels_4_i, cat_kernels_4_j, cat_kernels_4_k],
                                             dim=0)

        if x.dim() == 3:
            convfunc = F.conv1d
        elif x.dim() == 4:
            convfunc = F.conv2d
        elif x.dim() == 5:
            convfunc = F.conv3d
        else:
            raise InvalidInput("Given input channels do not match allowed dimensions")

        return convfunc(x, cat_kernels_4_quaternion, self.bias, self.stride, self.padding, self.dilation, self.groups)

    @staticmethod
    def weight_tensors(weight_shape, kernel_size):
        """Create the four quaternion weight tensors (r, i, j, k) in polar form.

        Each quaternion weight is ``modulus * (cos(phase) + u * sin(phase))``, where:

        - ``modulus`` is Xavier-uniform initialised;
        - ``phase`` is uniform in [-pi, pi);
        - ``u`` is a random imaginary direction whose components are normalised so
          that ``|u_i| + |u_j| + |u_k| == 1`` (L1 normalisation).

        :param weight_shape: Shape shared by all four tensors.
        :param kernel_size: Unused; kept for backward compatibility.
        :return: Tuple ``(r_weight, i_weight, j_weight, k_weight)`` of ``nn.Parameter``.
        """
        modulus = torch.empty(*weight_shape)
        nn.init.xavier_uniform_(modulus, gain=1.0)

        # Random imaginary direction, each component uniform in [-1, 1).
        i_dir = 2.0 * torch.rand(*weight_shape) - 1.0
        j_dir = 2.0 * torch.rand(*weight_shape) - 1.0
        k_dir = 2.0 * torch.rand(*weight_shape) - 1.0

        l1_norm = i_dir.abs() + j_dir.abs() + k_dir.abs()
        i_dir = i_dir / l1_norm
        j_dir = j_dir / l1_norm
        k_dir = k_dir / l1_norm

        phase = torch.rand(*weight_shape) * (2 * np.pi) - np.pi
        cos_phase = torch.cos(phase)
        sin_phase = torch.sin(phase)

        r_weight = modulus * cos_phase
        i_weight = modulus * i_dir * sin_phase
        j_weight = modulus * j_dir * sin_phase
        k_weight = modulus * k_dir * sin_phase

        return nn.Parameter(r_weight), nn.Parameter(i_weight), nn.Parameter(j_weight), nn.Parameter(k_weight)

    @staticmethod
    def get_weight_shape(in_channels, out_channels, kernel_size):
        """Return the per-component weight shape ``(out_channels, in_channels, *kernel_size)``.

        :param in_channels: Input channels per quaternion component.
        :param out_channels: Output channels per quaternion component.
        :param kernel_size: Sequence of spatial kernel extents.
        :return: The weight shape as a tuple.
        """
        return (out_channels, in_channels) + tuple(kernel_size)

    @staticmethod
    def get_kernel_shape(kernel_size, dimension):
        """Validate ``kernel_size`` against ``dimension`` and return it as a tuple.

        :param kernel_size: An int (replicated ``dimension`` times), or a tuple/list
            with exactly ``dimension`` entries.
        :param dimension: Spatial rank; must be in ``ALLOWED_DIMENSIONS``.
        :return: The kernel shape as a tuple of length ``dimension``.
        :raises InvalidKernelShape: On a disallowed dimension, a length mismatch,
            or an unsupported ``kernel_size`` type.
        """
        if dimension not in QuaternionConvolution.ALLOWED_DIMENSIONS:
            raise InvalidKernelShape('Given dimensions are not allowed.')

        if isinstance(kernel_size, int):
            return (kernel_size,) * dimension

        if isinstance(kernel_size, (tuple, list)):
            if len(kernel_size) != dimension:
                raise InvalidKernelShape('Given kernel shape does not match dimension.')

            return tuple(kernel_size)

        raise InvalidKernelShape('No valid type of kernel size to construct kernel.')

    def __repr__(self):
        """Summarise the layer; channel counts shown are per quaternion component."""
        return (f"{self.__class__.__name__}("
                f"in_channels={self.in_channels}"
                f", out_channels={self.out_channels}"
                f", kernel_size={self.kernel_size}"
                f", stride={self.stride})")

