import random
from typing import List, Tuple, Union

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from networkx.classes.reportviews import NodeView

# `typing` exports no `Str` (importing it raises ImportError); keep the name
# available as an alias of the builtin so annotations spelled `Str` resolve.
Str = str


# Base prompt text; the single positional `{}` placeholder is filled with the
# class label via `str.format` when prompts are built.
PROMPT_TEMPLATE = "A picture of a {}"

class DataGenerator:
    """Interface for data generators.

    Both methods raise ``NotImplementedError``; subclasses are expected to
    override them. Because ``__init__`` itself raises, a subclass must not
    delegate to it via ``super().__init__()``.
    """

    def __init__(self, *args, **kwargs):
        """Reject direct construction.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement __init__"
        )

    def generate(self, *args, **kwargs):
        """Produce generated data; must be overridden.

        Arbitrary arguments are accepted so that calling an un-overridden
        ``generate`` with a subclass-style argument list reports the missing
        implementation instead of failing with a ``TypeError``.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement generate"
        )
    
class StableDiffusionGenerator(DataGenerator):
    """Image generator backed by a diffusers ``StableDiffusionPipeline``.

    A prompt is ``PROMPT_TEMPLATE`` filled with a class label, followed by a
    comma-separated description collected from the nodes of a sample.
    """

    def __init__(self, model_name: str = "runwayml/stable-diffusion-v1-5", device: Union[str, torch.device, None] = None, *args, **kwargs):
        """Load the pipeline and move it to the target device.

        Args:
            model_name: Identifier handed to
                ``StableDiffusionPipeline.from_pretrained``.
            device: Device to run on. When falsy, CUDA is selected if
                ``torch.cuda.is_available()`` and CPU otherwise.
            *args: Accepted for signature compatibility; ignored.
            **kwargs: Accepted for signature compatibility; ignored.
        """
        # The base-class __init__ raises NotImplementedError, so it is
        # deliberately not invoked here.
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.pipe = StableDiffusionPipeline.from_pretrained(model_name).to(self.device)

        self.prompt_template = PROMPT_TEMPLATE

    def _format_prompt(self, label: str, nodes: List[NodeView]) -> str:
        """Build one prompt from a single label and the nodes of one sample."""
        prompt = self.prompt_template.format(label)
        # NOTE(review): elements are annotated as networkx NodeView but are
        # called with a zero-argument get(); presumably they are node objects
        # whose get() returns a str -- confirm against the caller.
        parts = [node.get() for node in nodes]
        if not parts:
            # No nodes: emit the bare prompt rather than a dangling ", ".
            return prompt
        desc = ", ".join(parts)
        return f"{prompt}, {desc}"

    def generate_prompt(self, class_label: Union[str, List[str]], samples: Union[List[NodeView], List[List[NodeView]]]) -> List[str]:
        """Build the text prompt(s) for one sample or a batch of samples.

        Args:
            class_label: One label, or a list with one label per sample. A
                single label combined with a batch of samples is applied to
                every sample in the batch.
            samples: The nodes of one sample, or a list of such node lists.
                Input is treated as a batch when its first element is a list.

        Returns:
            One prompt per sample, in input order.

        Raises:
            ValueError: If a list of labels is paired with a single
                (non-batched) sample, or the number of labels differs from
                the number of samples.
        """
        labels_batched = isinstance(class_label, list)
        # Guard the peek at samples[0] so an empty sample does not raise.
        samples_batched = bool(samples) and isinstance(samples[0], list)

        if not labels_batched and not samples_batched:
            return [self._format_prompt(class_label, samples)]

        if labels_batched:
            if samples and not samples_batched:
                raise ValueError(
                    "a list of class labels requires a list of samples"
                )
            labels = class_label
        else:
            # Broadcast the single label; zipping a str against the batch
            # would otherwise pair samples with individual characters.
            labels = [class_label] * len(samples)

        if len(labels) != len(samples):
            raise ValueError(
                f"got {len(labels)} class labels for {len(samples)} samples"
            )

        return [
            self._format_prompt(label, nodes)
            for label, nodes in zip(labels, samples)
        ]

    def generate(
        self,
        class_label: Union[str, List[str]],
        sample: Union[List[NodeView], List[List[NodeView]]],
        *,
        num_images_per_prompt: int = 1,
        num_inference_steps: int = 20,
    ) -> List[Image.Image]:
        """
        Generates images based on the given class label and sample. Input can be either a single sample result or a list of sample results.

        Args:
            class_label (Union[str, List[str]]): The class label(s) for the generated images.
            sample (Union[List[NodeView], List[List[NodeView]]]): The sample(s) used as input for generating images.
            num_images_per_prompt (int): Forwarded to the pipeline call. Defaults to 1.
            num_inference_steps (int): Forwarded to the pipeline call. Defaults to 20.

        Returns:
            List[Image.Image]: A list of generated images; empty when the input yields no prompts.

        Raises:
            ValueError: Propagated from ``generate_prompt`` when labels and samples do not line up.

        """
        prompts = self.generate_prompt(class_label, sample)
        if not prompts:
            # Empty batch: skip the pipeline call entirely.
            return []
        return self.pipe(
            prompts,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=num_inference_steps,
        ).images
