'@Author:NavinKumarMNK'
import pytorch_lightning as pl
import torch
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../../')))
import utils.utils as utils

import PIL
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import random_split, DataLoader

class MalBinImgDataset(torch.utils.data.Dataset):
    """Image dataset built from two annotation files found in ``root_dir``.

    ``malbinimgannotation_beg.txt`` and ``malbinimgannotation_mal.txt`` are
    each read as a list of image paths, one per line.  Items are returned as
    ``(image_tensor, label)`` where the label is ``1`` for a path taken from
    the ``mal`` list and ``0`` for a path taken from the ``beg`` list.

    Successive ``__getitem__`` calls alternate between the two lists (``mal``
    first), so the class served for a given index depends on call order, not
    only on the index.
    NOTE(review): because of that, a given index is not tied to one sample;
    with ``num_workers > 0`` each worker process presumably toggles its own
    copy of ``alt`` -- confirm this is the intended balancing scheme.
    """

    def __init__(self, root_dir):
        """Build the transform pipeline and load both annotation lists.

        Args:
            root_dir: Directory containing the two annotation text files.

        Raises:
            OSError: If either annotation file cannot be opened.
        """
        self.transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # Three-channel statistics: images must be RGB when they get here.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.beg_annotation = self._read_annotation(
            os.path.join(root_dir, 'malbinimgannotation_beg.txt'))
        self.mal_annotation = self._read_annotation(
            os.path.join(root_dir, 'malbinimgannotation_mal.txt'))
        # 0 -> next item comes from the mal list, anything else -> beg list.
        self.alt = 0

    @staticmethod
    def _read_annotation(path):
        """Return the non-empty lines of ``path`` without their line endings.

        The file is closed before returning.  Stripping here (instead of
        chopping the last character at lookup time) keeps the final path
        intact when the file does not end with a newline.
        """
        with open(path, 'r') as handle:
            stripped = (line.rstrip('\r\n') for line in handle)
            return [entry for entry in stripped if entry]

    def __len__(self):
        """Return the length of the longer of the two annotation lists."""
        return max(len(self.beg_annotation), len(self.mal_annotation))

    def _select_entry(self, index):
        """Pick ``(path, label)`` for ``index`` and flip the alternation flag.

        The index is wrapped with a modulo on whichever list is being served,
        so the shorter list is cycled instead of raising ``IndexError``.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
            ValueError: If the list selected for this call has no entries.
        """
        size = len(self)
        if index < 0:
            index += size
        # Explicit bound check: the modulo below would otherwise accept any
        # index and plain ``for item in dataset`` iteration would never stop.
        if not 0 <= index < size:
            raise IndexError(f'index out of range for dataset of size {size}')

        if self.alt == 0:
            entries, label, next_alt = self.mal_annotation, 1, 1
        else:
            entries, label, next_alt = self.beg_annotation, 0, 0
        if not entries:
            raise ValueError(f'annotation list for label {label} is empty')

        path = entries[index % len(entries)]
        self.alt = next_alt
        return path, label

    def __getitem__(self, index):
        """Load, transform and return ``(image_tensor, label)`` for ``index``."""
        path, label = self._select_entry(index)
        # The context manager releases the file handle; convert() forces the
        # pixel data to be read first and guarantees the three channels that
        # the Normalize step expects.
        with PIL.Image.open(path) as handle:
            img = handle.convert('RGB')
        img = self.transform(img)
        return img, int(label)

class MalBinImgDataModule(pl.LightningDataModule):
    """Lightning data module that splits ``MalBinImgDataset`` into train/val/test.

    The dataset is built from ``root_dir`` and divided with ``random_split``
    according to ``val_split`` and ``test_split``; the remainder is used for
    training.
    """

    def __init__(self, root_dir,  batch_size, num_workers, val_split=0.1, test_split=0.1):
        """Store the loader configuration.

        Args:
            root_dir: Directory passed to ``MalBinImgDataset``.
            batch_size: Batch size used by every dataloader.
            num_workers: Worker process count used by every dataloader.
            val_split: Fraction of the dataset reserved for validation.
            test_split: Fraction of the dataset reserved for testing.

        Raises:
            ValueError: If a fraction is negative or the two fractions sum to
                more than 1 (which would give a negative training length).
        """
        super().__init__()
        if val_split < 0 or test_split < 0 or val_split + test_split > 1:
            raise ValueError(
                'val_split and test_split must be non-negative and sum to at most 1'
            )
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.test_split = test_split
        # Filled in by setup(); None means "not split yet".
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """No download or preprocessing step is needed."""
        pass

    def setup(self, stage=None):
        """Create the train/val/test subsets once.

        ``stage`` is accepted for interface compatibility and not used.
        Repeated calls are no-ops: re-running ``random_split`` would draw a
        different permutation, so samples used for training in one call
        could land in the validation or test subset of the next.
        """
        if self.train_dataset is not None:
            return

        full_dataset = MalBinImgDataset(self.root_dir)
        total_len = len(full_dataset)
        val_len = int(self.val_split * total_len)
        test_len = int(self.test_split * total_len)
        # Training takes whatever is left after the truncated val/test sizes.
        train_len = total_len - val_len - test_len

        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            full_dataset, [train_len, val_len, test_len]
        )

    def _build_loader(self, dataset, persistent):
        """Return an unshuffled ``DataLoader`` over ``dataset``.

        ``persistent_workers`` is only requested when there are worker
        processes, because ``DataLoader`` rejects it with ``num_workers == 0``.
        """
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=persistent and self.num_workers > 0,
        )

    def train_dataloader(self):
        """Return the training dataloader (workers kept alive when present)."""
        return self._build_loader(self.train_dataset, persistent=True)

    def val_dataloader(self):
        """Return the validation dataloader (workers kept alive when present)."""
        return self._build_loader(self.val_dataset, persistent=True)

    def test_dataloader(self):
        """Return the test dataloader (workers are not kept alive)."""
        return self._build_loader(self.test_dataset, persistent=False)

if __name__ == '__main__':
    # Manual smoke test: report the dataset size, then print one training batch.
    config = utils.config_parse('MALBINIMG_DATASET')
    dataset = MalBinImgDataset(config['root_dir'])
    print(len(dataset))

    module = MalBinImgDataModule(
        config['root_dir'],
        config['batch_size'],
        config['num_workers'],
    )
    module.setup()
    loader = module.train_dataloader()
    # Only the first batch is shown; an empty loader prints nothing.
    for inputs, targets in loader:
        print(inputs, targets)
        break

