import pandas as pd
import inflect
import re
import spacy
import os
import argparse
# English spaCy pipeline, loaded once at import time and shared by
# extract_nouns() for part-of-speech tagging. The "en_core_web_sm" model
# package must be installed for this module to import successfully.
nlp = spacy.load("en_core_web_sm")


def parse_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]`` exactly as before;
            passing an explicit list allows programmatic use and testing.

    Returns:
        An ``argparse.Namespace`` with string attributes ``data_dir``,
        ``prompts_file`` and ``output_file``.
    """
    parser = argparse.ArgumentParser()

    # Directory of input images; also the base directory for the image paths
    # written to the output CSV.
    # NOTE(review): machine-specific absolute default; consider making this
    # argument required.
    parser.add_argument(
        '--data_dir', type=str,
        default='/home/prakanss/dataset_scripts_pid/cohyponym_dataset/cohyponym_images',
        help="Path to directory containing input images")
    # Optional file of "prompt,word1,word2,..." lines. An empty string means
    # captions are derived from the image filenames instead.
    parser.add_argument(
        '--prompts_file', type=str, default='',
        help="Path to file containing prompts and words of interest")
    parser.add_argument(
        '--output_file', type=str,
        default='../../data/coco/coco_hyponyms.csv',
        help="Save path")

    return parser.parse_args(argv)


def extract_nouns(text, pipeline=None):
    """Return the common-noun tokens of ``text`` in order of appearance.

    Each returned item is a single token whose coarse part-of-speech tag
    (``token.pos_``) is ``"NOUN"``. Multi-word noun phrases are NOT merged,
    and proper nouns (``"PROPN"``) are not included.

    Args:
        text: The text string to be processed.
        pipeline: Optional callable mapping ``text`` to an iterable of tokens
            exposing ``.pos_`` and ``.text`` (e.g. a spaCy ``Language``).
            Defaults to the module-level ``nlp`` pipeline.

    Returns:
        A list of noun token strings. It may contain duplicates and may be
        empty.
    """
    if pipeline is None:
        pipeline = nlp
    return [token.text for token in pipeline(text) if token.pos_ == "NOUN"]


def _rows_from_prompts(data_dir, prompts_file):
    """Build [img_path, prompt, words] rows from a "prompt,w1,w2,..." file."""
    rows = []
    with open(prompts_file, "r", encoding="utf-8") as f:
        for idx, line in enumerate(f):
            print(f"Processing {idx}")

            prompt, *words = line.split(',')
            prompt = prompt.strip()
            # Strip the trailing newline and any spaces around commas, and drop
            # empty fields (e.g. from a trailing comma), so words are clean.
            words = [w.strip() for w in words if w.strip()]

            if len(words) < 2:
                continue

            img_path = os.path.join(data_dir, f"{prompt.replace(' ', '_')}.jpg")
            rows.append([img_path, prompt, words])
    return rows


def _rows_from_images(data_dir):
    """Build [img_path, caption, nouns] rows from the image names in data_dir."""
    rows = []
    # Skip hidden files (e.g. ".DS_Store") and sort so the output order is
    # deterministic; os.listdir returns entries in arbitrary order.
    names = sorted(n for n in os.listdir(data_dir) if not n.startswith('.'))
    for idx, img in enumerate(names):
        print(f"Processing {idx}")

        img_path = os.path.join(data_dir, img)
        # splitext removes only the final extension, so captions that
        # contain dots (e.g. "2.5_apples.jpg") are kept intact.
        caption = os.path.splitext(img)[0].replace("_", " ")
        nouns = extract_nouns(caption)

        if len(nouns) < 2:
            continue

        rows.append([img_path, caption, nouns])
    return rows


def construct_expanded_dataset(data_dir, prompts_file):
    """Build a DataFrame of (image path, caption, objects) rows.

    Two input modes:
      * ``prompts_file`` non-empty: each line is ``prompt,word1,word2,...``.
        The image path is ``data_dir/<prompt with spaces -> '_'>.jpg`` and the
        whitespace-stripped words are the objects.
      * ``prompts_file`` empty: every non-hidden entry in ``data_dir`` is used.
        Its name minus the extension, with '_' -> ' ', is the caption, and the
        nouns extracted from that caption are the objects.

    Rows with fewer than two objects are dropped.

    Args:
        data_dir: Directory holding the images; prefixes every image path.
        prompts_file: Path to the prompts file, or '' to derive captions from
            the image filenames.

    Returns:
        pandas.DataFrame with columns 'image path', 'caption' and 'objs'
        (each 'objs' cell is a list of strings).
    """
    if prompts_file:
        expanded_data = _rows_from_prompts(data_dir, prompts_file)
    else:
        expanded_data = _rows_from_images(data_dir)

    print("Number of samples = %d" % (len(expanded_data)))
    return pd.DataFrame(expanded_data, columns=['image path', 'caption', 'objs'])

    
def main():
    """Script entry point: parse CLI args, build the dataset, write it as CSV."""
    args = parse_args()
    df = construct_expanded_dataset(args.data_dir, args.prompts_file)

    # DataFrame.to_csv does not create missing parent directories, so make
    # sure the destination exists rather than losing the whole run at the end.
    out_dir = os.path.dirname(args.output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    df.to_csv(args.output_file, index=False)


# Build the dataset only when run as a script, never as a side effect of import.
if __name__ == "__main__":
    main()
