import numpy as np
import pandas as pd
import argparse

from castle.algorithms import Notears,PC, GES, TTPM


def build_model(method: str, topology_path=None, max_ttpm_hops: int = 0):
    """Instantiate the causal-discovery model selected on the command line.

    Args:
        method: One of "notears", "pc", "ges" or "ttpm".
        topology_path: Optional path to a ``.npy`` topology matrix; only used
            by "ttpm". When omitted a 1x1 zero matrix (single node) is used.
        max_ttpm_hops: ``max_hop`` argument forwarded to ``TTPM``.

    Returns:
        A ``(model, transform)`` pair where ``transform`` is "nonTime" for the
        algorithms that consume a window-count matrix and "ttpm" for the
        algorithm that consumes the raw event table.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Castle classes are looked up lazily (inside the function) so importing
    # this module never touches them.
    windowed_models = {"notears": Notears, "pc": PC, "ges": GES}
    if method in windowed_models:
        return windowed_models[method](), "nonTime"
    if method == "ttpm":
        if topology_path is None:
            topology = np.zeros((1, 1))
        else:
            topology = np.load(topology_path)
        return TTPM(topology, max_hop=max_ttpm_hops), "ttpm"
    raise ValueError("method not recognized")


def build_window_samples(alarms: pd.DataFrame, time_win_size: int) -> pd.DataFrame:
    """Count alarms per fixed-size time window.

    Each row of ``alarms`` is assigned to window
    ``int(start_timestamp / time_win_size)``; the result has one row per
    non-empty window, one column per ``alarm_id`` (sorted), and the number of
    occurrences as values (0 where an alarm did not fire in a window).

    Args:
        alarms: Table with at least ``alarm_id`` and ``start_timestamp``
            columns. It is not modified.
        time_win_size: Window length, in the same unit as ``start_timestamp``.

    Raises:
        ValueError: If ``time_win_size`` is not strictly positive.
    """
    if time_win_size <= 0:
        # Zero would divide by zero; a negative size gives meaningless windows.
        raise ValueError("time window size must be a positive integer")

    # Vectorised equivalent of int(ts / size): true division followed by
    # truncation toward zero.
    win_id = (alarms["start_timestamp"] / time_win_size).astype(int).rename("win_id")

    counts = (
        alarms.groupby(["alarm_id", win_id])["start_timestamp"]
        .count()
        .unstack("alarm_id")
    )
    counts = counts.dropna(how="all").fillna(0)
    return counts.sort_index(axis=1)


def build_ttpm_events(alarms: pd.DataFrame, *, keep_nodes: bool) -> pd.DataFrame:
    """Rename alarm columns to the ``event`` / ``timestamp`` / ``node`` layout.

    Args:
        alarms: Table with ``alarm_id``, ``start_timestamp`` and, optionally,
            ``device_id`` columns. It is not modified; a renamed copy is
            returned.
        keep_nodes: When true and a device column is present, the device ids
            are kept as node ids so they can be matched against a topology.
            Otherwise every event is placed on node 0, which matches the
            single-node topology used when no topology file is supplied.
    """
    events = alarms.rename(
        columns={
            "alarm_id": "event",
            "start_timestamp": "timestamp",
            "device_id": "node",
        }
    )
    if not keep_nodes or "node" not in events.columns:
        # Without a real topology (or without device ids) all events are
        # attributed to one synthetic node.
        events["node"] = 0
    return events


def main(argv=None) -> None:
    """Parse the command line, fit the chosen model and save its causal matrix.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-method", type=str, required=True, help="method string ")
    argparser.add_argument("-alarms", required=True, type=str, help="path to alarms csv")
    argparser.add_argument("-output", required=True, type=str, help="path to output file")
    argparser.add_argument("-time-win-size", type=int, help="time window size, optional used for data transformation", default=300)
    argparser.add_argument("-max-ttpm-hops", type=int, help="max hops parameter for ttpm", default=0)
    argparser.add_argument("-topology", type=str, help="path to topology file", default=None)
    args = argparser.parse_args(argv)

    # Validate the method (and load the topology) before reading the alarms
    # file, so a mistyped method fails immediately.
    model, transform_data = build_model(
        args.method,
        topology_path=args.topology,
        max_ttpm_hops=args.max_ttpm_hops,
    )

    alarms = pd.read_csv(args.alarms)

    if transform_data == "nonTime":
        samples = build_window_samples(alarms, args.time_win_size)
    else:
        # Device ids are only meaningful when a topology describes them.
        samples = build_ttpm_events(alarms, keep_nodes=args.topology is not None)

    model.learn(samples)
    # Note: np.save appends ".npy" to the path if it lacks that extension.
    np.save(args.output, model.causal_matrix)


if __name__ == "__main__":
    main()