import numpy as np


class SphereSampleManager(object):
    """Breadth-first sampler over the vertices of a mesh projected onto the unit sphere.

    Vertices are normalised to unit length. ``get_init`` snaps start points to
    their nearest vertices and makes those the search front. Each
    ``next_search`` step then returns the not-yet-used neighbours of the
    current front (mesh adjacency from ``faces``), scaled by ``distance``.

    Iterating the manager yields such batches indefinitely; ``__next__`` never
    raises ``StopIteration``. Once every vertex is used, it returns random
    batches of ``batch_size`` vertices.

    NOTE: ``verts`` is normalised *in place*, so the caller's array is modified
    (and must therefore have a floating-point dtype).
    """

    def __init__(self, verts, faces, distance=5, batch_size=-1, return_idx=False):
        """
        :param verts: [n_vert, 3] vertex positions; normalised in place.
        :param faces: [n_face, k] vertex indices of each face (used for adjacency).
        :param distance: scale factor applied to the returned unit-sphere positions.
        :param batch_size: max number of vertices per ``__next__`` step. A value
            <= 0 means "no limit" for BFS steps; the random fallback used once
            all vertices are consumed requires a value > 0.
        :param return_idx: if True, sampling methods return ``(positions, indices)``.
        """
        # NOTE: in-place on purpose; the caller's array ends up normalised too.
        verts /= np.sum(verts ** 2, axis=1, keepdims=True) ** .5
        self.used = np.zeros(verts.shape[0], dtype=bool)
        self.adj_dict = self.get_adjunct_dict(faces, verts.shape[0])
        self.verts = verts
        self.distance = distance
        self.front = []
        self.batch_size = batch_size
        self.return_idx = return_idx

    def get_idx(self, target):
        """Return, for each row of ``target`` ([n_tar, 3]), the index of the nearest vertex."""
        # [n_vert, n_tar] squared distances
        this_dis = ((np.expand_dims(target, axis=0) - np.expand_dims(self.verts, axis=1)) ** 2).sum(2)
        return np.argmin(this_dis, axis=0)

    @staticmethod
    def get_adjunct_dict(faces, n):
        """Map each vertex index in ``range(n)`` to the set of vertices sharing a face with it.

        A vertex that appears in at least one face is included in its own set.
        Single pass over the faces (O(n_face)), not one scan of all faces per vertex.
        """
        get = {i: set() for i in range(n)}
        for face in np.asarray(faces).tolist():
            for v in face:
                if v in get:
                    get[v].update(face)
        return get

    def get_init(self, start_point=None):
        """Reset the sampler and seed the front with the vertices nearest to ``start_point``.

        :param start_point: [n_start, 3] (or a single [3]) point(s); defaults to
            ``[[0, 0, 1]]``.
        :return: [n_start, 3] positions scaled by ``distance`` (plus the vertex
            indices when ``return_idx`` is set).
        """
        if start_point is None:
            start_point = np.array([[0, 0, 1]])
        # a single 3-vector would otherwise broadcast along the wrong axis
        start_point = np.atleast_2d(start_point)
        # [n_start, n_vert] squared distances -> nearest vertex per start point
        sq_dist = ((np.expand_dims(self.verts, axis=0) - np.expand_dims(start_point, axis=1)) ** 2).sum(2)
        idx = np.argmin(sq_dist, axis=1)
        self.used = np.zeros(self.verts.shape[0], dtype=bool)
        self.used[idx] = True
        self.front = idx.tolist()
        positions = self.verts[idx] * self.distance
        if self.return_idx:
            return positions, idx
        return positions

    def next_search_(self, max_size=1e10):
        """Expand the BFS front by one step and return the newly reached vertex indices.

        Front vertices are expanded in order. Their unused neighbours are
        collected until adding the next vertex's neighbours would exceed
        ``max_size``. In that case only enough of them to reach ``max_size``
        are taken and that front vertex stays in the front, so its remaining
        neighbours are picked up by a later call. Fully expanded front
        vertices are removed, the new vertices are appended to the front, and
        they are marked as used.

        :param max_size: max number of vertices to return; None or <= 0 means no limit.
        :return: list of vertex indices (possibly empty).
        """
        limit = None if max_size is None or max_size <= 0 else int(max_size)
        get = set()
        n_consumed = 0

        for t in self.front:
            # ``used`` is only updated after the loop, so exclude what we already collected
            fresh = {k for k in self.adj_dict[t] if not self.used[k]} - get
            if limit is not None and len(get) + len(fresh) > limit:
                get.update(list(fresh)[:limit - len(get)])
                break
            get |= fresh
            n_consumed += 1

        # drop the fully expanded prefix of the front (never mutate it while iterating)
        del self.front[:n_consumed]

        get = list(get)
        self.front += get
        if get:
            self.used[np.array(get)] = True
        return get

    def next_search(self, max_size=1e10):
        """Run one BFS step and return the reached positions scaled by ``distance``.

        If the step reaches no new vertex, this falls back to
        ``random_get(max_size)``; that fallback requires an integer
        ``0 < max_size <= n_vert``.
        """
        get = self.next_search_(max_size=max_size)
        if len(get) == 0:
            return self.random_get(max_size)
        idx = np.array(get)
        if self.return_idx:
            return self.verts[idx] * self.distance, idx
        return self.verts[idx] * self.distance

    def if_maxed(self, bias_m=0):
        """True when at least ``n_vert + bias_m`` vertices have been marked as used."""
        return np.sum(self.used) >= self.verts.shape[0] + bias_m

    def random_get(self, n):
        """Return ``n`` distinct vertices drawn uniformly at random (``used`` is not updated)."""
        idx = np.random.choice(self.verts.shape[0], n, replace=False)
        if self.return_idx:
            return self.verts[idx] * self.distance, idx
        return self.verts[idx] * self.distance

    def __next__(self):
        """BFS step of ``batch_size`` while unused vertices remain, else a random batch."""
        if self.if_maxed():
            if self.batch_size <= 0:
                raise Exception('batch_size must > 0 for getitem')
            return self.random_get(self.batch_size)
        return self.next_search(self.batch_size)

    def __iter__(self):
        return self


def camera_position_to_spherical_angle_np(camera_pose):
    """Convert Cartesian positions to spherical (distance, elevation, azimuth).

    :param camera_pose: [n, 3] array of (x, y, z) positions.
    :return: tuple ``(distance, elevation, azimuth)`` of [n] arrays:
        distance = ||p||; elevation = arcsin(y / distance) in [-pi/2, pi/2];
        azimuth = angle of (x, z) measured from +z towards +x, wrapped into [0, 2*pi).
    """
    distance_o = np.sum(camera_pose ** 2, axis=1) ** .5
    # arctan2 needs no x / z division (no divide-by-zero at z == 0) and yields 0
    # on the +z axis (x == 0, z > 0), where the arctan + quadrant-shift form gave pi.
    azimuth_o = np.arctan2(camera_pose[:, 0], camera_pose[:, 2]) % (2 * np.pi)
    # clip guards against |y / distance| drifting just past 1 through rounding
    elevation_o = np.arcsin(np.clip(camera_pose[:, 1] / distance_o, -1, 1))
    return distance_o, elevation_o, azimuth_o


def vertices_filter(verts, faces, filter_func):
    """Keep the vertices selected by ``filter_func`` and the faces made only of kept vertices.

    :param verts: [n_vert, ...] vertex array.
    :param faces: [n_face, c] vertex indices per face.
    :param filter_func: callable mapping ``verts`` to a [n_vert] boolean mask.
    :return: ``(verts[mask], new_faces)``. ``new_faces`` holds every face whose
        vertices all survive, re-indexed into the filtered vertex array; its
        dtype is the dtype of ``faces``.
    """
    # (boolean) [n_vert, ]
    mask = np.asarray(filter_func(verts), dtype=bool)
    faces = np.asarray(faces)
    face_idx = faces.astype(np.intp, copy=False)

    # (boolean) [n_face, ]: a face survives only if every one of its vertices is kept
    mask_faces = mask[face_idx].all(axis=1)

    # old vertex index -> new index (-1 for dropped vertices, which never occur in kept faces)
    remap = np.full(mask.shape[0], -1, dtype=np.intp)
    remap[mask] = np.arange(np.count_nonzero(mask))

    return verts[mask], remap[face_idx[mask_faces]].astype(faces.dtype, copy=False)


if __name__ == '__main__':
    # Demo: animate the BFS sampling order over a template mesh and save it as
    # tem2.gif (green dots = already sampled vertices, red dots = not yet sampled).
    from MeshUtils import load_off, save_off
    from PIL import Image, ImageDraw
    import BboxTools as bbt
    from MeshUtils import pre_process_mesh_pascal

    template_path = './UVsamples4p1.off'
    # template_path = './ICOsamples3r4p.off'
    verts, faces = load_off(template_path)
    verts = pre_process_mesh_pascal(verts)

    # Commented-out recipe that filters UVsamples4.off and writes UVsamples4p1.off:
    # template_path = './UVsamples4.off'
    # verts, faces = load_off(template_path)
    # verts = pre_process_mesh_pascal(verts)
    # filter_foo = lambda x: np.logical_and(np.abs(camera_position_to_spherical_angle_np(x)[1]) < np.pi * (0.4 / 18),
    #                                       camera_position_to_spherical_angle_np(x)[1] > -1e-2)
    # verts, faces = vertices_filter(verts, faces, filter_func=filter_foo)
    #
    # save_off('./UVsamples4p1.off', verts, faces)
    #
    # exit()

    # NOTE(review): pre_process_mesh_pascal is applied a second time here (it is
    # already applied right after loading); confirm the double application is intended.
    verts = pre_process_mesh_pascal(verts)
    # The constructor normalises ``verts`` in place, so the projection below
    # works on unit-length vertex positions.
    manager = SphereSampleManager(verts, faces, batch_size=10)
    # print(manager.get_init(start_point=np.array([[0, 0, 1], [0, 0, -1]])).shape)

    start_point = np.array([[0, 0, 1], [1, 0, 0], [-1, 0, 0]])
    print(manager.get_init(start_point=start_point).shape)

    image = np.zeros((512, 512, 3), dtype=np.uint8)
    principal = np.array([image.shape[0:2]]) / 2

    R = np.array(
        [[-1., 0., 0.],
         [0., 1., 0.],
         [0., 0., -1.]]
    )
    T = np.array([[0, 0, 12.5]])

    # pinhole projection of every vertex to (row, col) pixel coordinates
    point2d = verts @ R + T
    point2d = principal - np.concatenate([point2d[:, 1:2] / point2d[:, 2:3], point2d[:, 0:1] / point2d[:, 2:3]], axis=1) * 3000

    # print(next(manager))

    get_imgs = []
    try:
        for i, k in enumerate(manager):
            print(k.shape, manager.used.sum(), manager.verts.shape[0])
            # the manager iterates forever, so stop once everything is sampled or after 200 steps
            if manager.if_maxed():
                break
            if i == 200:
                break
            img = Image.fromarray(image)
            imd = ImageDraw.Draw(img)
            for j, point_ in enumerate(point2d):
                if manager.used[j]:
                    color = (0, 255, 0)
                else:
                    color = (255, 0, 0)

                box = bbt.box_by_shape((5, 5), point_)
                imd.ellipse(box.pillow_bbox(), fill=color)
            # img.show()

            # appended twice, so each sampling state occupies two consecutive GIF frames
            get_imgs.append(img)
            get_imgs.append(img)
    except Exception as err:
        # best effort: keep the frames rendered so far, but do not hide the failure
        print('sampling stopped early: %r' % (err,))

    if get_imgs:
        get_imgs[0].save('tem2.gif', save_all=True, append_images=get_imgs[1::])
    else:
        print('no frames were rendered; tem2.gif not written')




