import glob
import os 
from PIL import Image
import torchvision.transforms as transforms
import torch.utils.data as udata 

class ImageDataset(udata.Dataset):
    """Paired (noisy, clean) image dataset for denoising tasks.

    Noisy images are the files matching ``*_real.png`` and clean images are
    the files matching ``*_mean.png``. Both file lists are sorted, and the
    i-th noisy file is paired with the i-th clean file.

    Args:
        path: Root directory of the dataset.
        noisy_image: Sub-path/prefix under ``path`` where noisy images live.
            It is joined to ``path`` and the glob pattern is appended
            directly, so a sub-directory must end with a path separator
            (e.g. ``"noisy/"``).
        clean_image: Sub-path/prefix under ``path`` where clean images live,
            using the same convention. When empty (the default), clean images
            are looked up under ``noisy_image``, as this class always did.
        temps: Accepted for backward compatibility; currently unused.

    Raises:
        ValueError: If the number of noisy and clean images differs. The
            index-based pairing would otherwise silently mismatch files.
    """

    def __init__(self, path, noisy_image="", clean_image="", temps=None):
        noisy_prefix = os.path.join(path, noisy_image)
        # Fall back to the noisy prefix so callers that only pass
        # ``noisy_image`` keep finding clean images in the same place.
        clean_prefix = os.path.join(path, clean_image or noisy_image)
        self.noisy_image = sorted(glob.glob(noisy_prefix + '*_real.png'))
        self.clean_image = sorted(glob.glob(clean_prefix + '*_mean.png'))
        if len(self.noisy_image) != len(self.clean_image):
            raise ValueError(
                f"Found {len(self.noisy_image)} noisy images but "
                f"{len(self.clean_image)} clean images; they must pair 1:1."
            )

    def __getitem__(self, index):
        """Return the ``index``-th pair as ``{'input': noisy, 'target': clean}``.

        Both images are converted with ``transforms.ToTensor``. If the noisy
        image's size differs from the clean image's size, the noisy image is
        first resized to match the clean one.
        """
        to_tensor = transforms.ToTensor()
        noisy_path = self.noisy_image[index]
        clean_path = self.clean_image[index]
        # Image.open is lazy and keeps the file open; the context managers
        # release the handles once the pixel data has been converted.
        with Image.open(noisy_path) as noisy_img, Image.open(clean_path) as clean_img:
            if noisy_img.size != clean_img.size:
                # PIL reports size as (width, height); Resize takes (height, width).
                width, height = clean_img.size
                noisy = to_tensor(transforms.Resize((height, width))(noisy_img))
            else:
                noisy = to_tensor(noisy_img)
            clean = to_tensor(clean_img)
        return {'input': noisy, 'target': clean}

    def __len__(self):
        """Return the number of (noisy, clean) pairs."""
        return len(self.noisy_image)
