import os
import ase.io
import os.path as osp
from .trajdata import TrajData
from .trajdataset import TrajDataset
import numpy as np
import random
import pickle
import lmdb
import torch
from tqdm import tqdm
from .utils import AtomsToGraphs
from torch_geometric.data import Data
import multiprocessing as mp




def read_trajectory_and_extract_features(a2g, traj_path, all_traj_idx, n_frames_given):
    """Build one graph ``Data`` sample from selected frames of a trajectory.

    Args:
        a2g: Feature extractor exposing ``_get_neighbors_pymatgen`` and
            ``_reshape_features`` (an ``AtomsToGraphs`` instance in this file).
        traj_path: Path of the ASE trajectory file to read.
        all_traj_idx: Frame indices to extract; their order is the order of
            the frames along the last axis of ``pos``.
        n_frames_given: Number of leading selected frames that count as
            "given". Edges are built from selected frame
            ``n_frames_given - 1``.

    Returns:
        A ``torch_geometric.data.Data`` holding ``atomic_numbers``, ``cell``
        (viewed as ``[1, 3, 3]``) and ``tags`` taken from the first selected
        frame, the stacked positions ``pos`` and the edge tensors
        ``edge_index``, ``edge_distances`` and ``cell_offsets``.

    Raises:
        ValueError: If ``n_frames_given`` is smaller than 1.
        IndexError: If ``all_traj_idx`` is empty or ``n_frames_given``
            exceeds the number of selected frames.
    """
    # A value <= 0 would wrap around through negative indexing and silently
    # build the edges from a frame at the end of the selection instead of
    # from the last *given* frame.
    if n_frames_given < 1:
        raise ValueError(
            f"n_frames_given must be >= 1, got {n_frames_given}"
        )

    # Close the reader as soon as the frames are loaded so that processing
    # many trajectories does not accumulate open file handles.
    with ase.io.trajectory.Trajectory(traj_path) as traj:
        frames = [traj[idx] for idx in all_traj_idx]

    # Static per-system quantities are taken from the first selected frame.
    first_frame_atoms = frames[0]
    atomic_numbers = torch.Tensor(first_frame_atoms.get_atomic_numbers())
    cell = torch.Tensor(np.array(first_frame_atoms.get_cell())).view(1, 3, 3)
    tags = torch.Tensor(first_frame_atoms.get_tags())

    # Stack the per-frame positions along a new trailing axis.
    pos = torch.stack(
        [torch.Tensor(atoms.get_positions()) for atoms in frames], dim=-1
    )  # [N, 3, T]

    # Construct the edges based on the configuration of the last given frame.
    last_given_frame_atoms = frames[n_frames_given - 1]
    split_idx_dist = a2g._get_neighbors_pymatgen(last_given_frame_atoms)
    edge_index, edge_distances, cell_offsets = a2g._reshape_features(
        *split_idx_dist
    )

    return Data(
        atomic_numbers=atomic_numbers,
        cell=cell,
        tags=tags,
        pos=pos,
        edge_index=edge_index,
        edge_distances=edge_distances,
        cell_offsets=cell_offsets,
    )


def write_images_to_pickle(mp_arg):
    """Convert trajectories to graph samples and collect them in a list.

    Args:
        mp_arg: Tuple ``(a2g, db_path, samples, sampled_ids, idx)`` where
            ``a2g`` is the feature extractor, ``db_path`` is unused here
            (kept so the tuple layout matches ``write_images_to_lmdb``),
            ``samples`` is a sequence of trajectory file paths,
            ``sampled_ids`` is a list that receives one
            ``"<path>,<num_frames>\\n"`` entry per accepted trajectory, and
            ``idx`` is the starting sample counter.

    Returns:
        ``(all_data, sampled_ids, idx)``: the list of generated samples, the
        (mutated in place) ``sampled_ids`` list, and the counter advanced by
        the number of accepted trajectories.
    """
    a2g, _db_path, samples, sampled_ids, idx = mp_arg

    # Sampling scheme: trajectories need at least this many frames, and
    # evenly strided frames are taken from that leading window.
    tot_n_frames_considered = 50
    n_frames_given, n_frames_pred = 5, 5
    tot_n_frames = n_frames_given + n_frames_pred
    stride = tot_n_frames_considered // tot_n_frames
    # Loop-invariant, so computed once instead of once per trajectory.
    all_traj_idx = [frame * stride for frame in range(tot_n_frames)]

    all_data = []
    for traj_path in tqdm(samples, total=len(samples)):
        # Only the frame count is needed here; close the reader right away
        # so no file handle is leaked per trajectory.
        with ase.io.trajectory.Trajectory(traj_path) as traj:
            num_frames = len(traj)
        if num_frames < tot_n_frames_considered:
            # Too short to sample from; skip without advancing the counter.
            continue

        # Pass the local value rather than a duplicated literal so the two
        # cannot drift apart.
        data = read_trajectory_and_extract_features(
            a2g, traj_path, all_traj_idx, n_frames_given=n_frames_given
        )

        all_data.append(data)
        idx += 1
        sampled_ids.append(f"{traj_path},{num_frames}\n")
    return all_data, sampled_ids, idx


def write_images_to_lmdb(mp_arg):
    """Convert trajectories to graph samples and store them in an LMDB file.

    Each accepted trajectory is pickled and written under the ASCII key of
    its running counter. After all samples, the final counter value is
    written under the key ``"length"``.

    Args:
        mp_arg: Tuple ``(a2g, db_path, samples, sampled_ids, idx)`` where
            ``a2g`` is the feature extractor, ``db_path`` is the LMDB file
            path (opened with ``subdir=False``), ``samples`` is a sequence of
            trajectory file paths, ``sampled_ids`` is a list that receives
            one ``"<path>,<num_frames>\\n"`` entry per accepted trajectory,
            and ``idx`` is the starting sample counter / first key.

    Returns:
        ``(sampled_ids, idx)``: the (mutated in place) ``sampled_ids`` list
        and the counter advanced by the number of accepted trajectories.
    """
    a2g, db_path, samples, sampled_ids, idx = mp_arg

    # Sampling scheme: trajectories need at least this many frames, and
    # evenly strided frames are taken from that leading window.
    tot_n_frames_considered = 50
    n_frames_given, n_frames_pred = 5, 5
    tot_n_frames = n_frames_given + n_frames_pred
    stride = tot_n_frames_considered // tot_n_frames
    # Loop-invariant, so computed once instead of once per trajectory.
    all_traj_idx = [frame * stride for frame in range(tot_n_frames)]

    db = lmdb.open(
        db_path,
        map_size=1099511627776 * 2,
        subdir=False,
        meminit=False,
        map_async=True,
    )
    try:
        for traj_path in tqdm(samples, total=len(samples)):
            # Only the frame count is needed here; close the reader right
            # away so no file handle is leaked per trajectory.
            with ase.io.trajectory.Trajectory(traj_path) as traj:
                num_frames = len(traj)
            if num_frames < tot_n_frames_considered:
                # Too short to sample from; skip without consuming a key.
                continue

            # Pass the local value rather than a duplicated literal so the
            # two cannot drift apart.
            data = read_trajectory_and_extract_features(
                a2g, traj_path, all_traj_idx, n_frames_given=n_frames_given
            )

            # One write transaction per sample; the context manager commits
            # on success and aborts if the put raises.
            with db.begin(write=True) as txn:
                txn.put(
                    f"{idx}".encode("ascii"), pickle.dumps(data, protocol=-1)
                )
            idx += 1
            sampled_ids.append(f"{traj_path},{num_frames}\n")

        # Save count of objects in lmdb.
        with db.begin(write=True) as txn:
            txn.put("length".encode("ascii"), pickle.dumps(idx, protocol=-1))

        db.sync()
    finally:
        # Release the environment even when a trajectory fails to process.
        db.close()

    return sampled_ids, idx


class OC(TrajDataset):
    """Trajectory dataset built from a directory of ASE trajectory files.

    Every file in ``raw_path`` is converted into one graph sample by
    ``write_images_to_pickle``; the resulting list is pickled to
    ``processed_file()``. Items are returned as ``TrajData`` objects.

    NOTE(review): ``self.data``, ``self.root`` and ``self.name`` are not
    assigned in this class; presumably ``TrajDataset`` sets them (and calls
    ``preprocess_raw`` when needed) -- confirm against the base class.
    """

    def __init__(self, root, name, raw_path, force_reprocess=False,
                 force_length_train=None, force_length_val=None, force_length_test=None):
        """Configure the feature extractor and delegate to ``TrajDataset``.

        Args:
            root: Directory holding the processed file and ``split.pkl``.
            name: Dataset name; used as the stem of the processed file.
            raw_path: Directory containing the raw trajectory files.
            force_reprocess: Forwarded to ``TrajDataset.__init__``.
            force_length_train: If not ``None``, keep only this many leading
                training indices in ``get_split``.
            force_length_val: Same, for the validation split.
            force_length_test: Same, for the test split.
        """
        self.raw_path = raw_path
        # Initialize feature extractor.
        self.a2g = AtomsToGraphs(
            max_neigh=50,
            radius=6,
            r_energy=True,
            r_forces=True,
            r_distances=False,
            r_fixed=True,
        )
        self.force_length_train = force_length_train
        self.force_length_val = force_length_val
        self.force_length_test = force_length_test
        # Attributes above are set first because the base initializer may
        # need them (e.g. to run preprocessing) -- TODO confirm.
        super().__init__(root=root, name=name, force_reprocess=force_reprocess)

    def processed_file(self):
        """Return the path of the processed dataset file.

        Despite the ``.pt`` suffix, ``preprocess_raw`` writes this file with
        ``pickle.dump``.
        """
        return osp.join(self.root, self.name + '.pt')

    def preprocess_raw(self):
        """Convert all raw trajectories and pickle the resulting sample list."""
        sampled_ids, idx = [], 0
        dp_path = self.processed_file()
        # os.listdir returns entries in arbitrary order. Sorting makes the
        # sample order deterministic, so that the seeded index split in
        # get_split refers to the same systems on every machine.
        all_files = [
            osp.join(self.raw_path, fname)
            for fname in sorted(os.listdir(self.raw_path))
        ]
        mp_args = (
            self.a2g,
            dp_path,
            all_files,
            sampled_ids,
            idx,
        )
        # sampled_ids, idx = write_images_to_lmdb(mp_args)
        all_data, sampled_ids, idx = write_images_to_pickle(mp_args)
        with open(dp_path, 'wb') as f:
            pickle.dump(all_data, f)

    def postprocess(self):
        """No post-processing is needed for this dataset."""
        pass

    def get_split(self):
        """Return train/val/test subsets, creating the split file if missing.

        The split is an 80/10/10 partition of a seed-42 shuffle of the sample
        indices, cached in ``<root>/split.pkl``. An existing split file is
        reused as is. The optional ``force_length_*`` limits then truncate
        each index list.

        Returns:
            Dict with keys ``'train'``, ``'val'`` and ``'test'`` mapping to
            ``torch.utils.data.Subset`` views of this dataset.
        """
        split_path = osp.join(self.root, 'split.pkl')
        print(split_path)
        if not osp.exists(split_path):  # if no split file
            # Generate splits by splitting different systems
            indexes = list(range(len(self.data)))
            total_n = len(indexes)
            random.Random(42).shuffle(indexes)
            train_end = int(total_n * 0.8)
            val_end = int(total_n * 0.9)
            new_split = (
                indexes[:train_end],
                indexes[train_end:val_end],
                indexes[val_end:],
            )
            with open(split_path, 'wb') as f:
                pickle.dump(new_split, f)
                print(f'split file generated and dumped to {split_path}')
        # NOTE(review): a split file left over from a differently sized
        # dataset would be loaded unchecked here.
        with open(split_path, 'rb') as f:
            print('Got split file')
            train_indexes, val_indexes, test_indexes = pickle.load(f)

        if self.force_length_train is not None:
            train_indexes = train_indexes[: self.force_length_train]
        if self.force_length_val is not None:
            val_indexes = val_indexes[: self.force_length_val]
        if self.force_length_test is not None:
            test_indexes = test_indexes[: self.force_length_test]

        return {
            'train': torch.utils.data.Subset(self, train_indexes),
            'val': torch.utils.data.Subset(self, val_indexes),
            'test': torch.utils.data.Subset(self, test_indexes),
        }

    def __len__(self):
        """Return the number of processed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sample ``idx`` wrapped as a ``TrajData`` object.

        Positions become ``x``, atomic numbers become ``h`` and the edge
        distances get a trailing singleton axis as ``edge_attr``. The sample
        index is attached as ``system_id`` and the edge count as
        ``neighbors``.
        """
        raw = self.data[idx]
        sample = TrajData(x=raw.pos, h=raw.atomic_numbers,
                          edge_index=raw.edge_index,
                          edge_attr=raw.edge_distances.unsqueeze(-1),
                          cell=raw.cell, cell_offsets=raw.cell_offsets)
        sample['system_id'] = torch.ones(1) * idx
        # Number of edges, i.e. the second dimension of edge_index.
        sample['neighbors'] = sample.edge_index.size(1)
        return sample


