#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>
#include <pthread.h>
#include <vector>
#include <string>
#include <map>
#include <set>

#define MAX_STRING 1000

// Logistic function: maps any real x into (0, 1).
// Saturates cleanly: for very negative x, exp(-x) overflows to +inf and the
// result is exactly 0.0.
double sigmoid(double x)
{
    const double neg_exp = exp(-x);
    return 1.0 / (neg_exp + 1.0);
}

// Inverse of sigmoid (the logit): returns y such that sigmoid(y) == x.
// Only finite for 0 < x < 1; x == 1 yields +inf and x == 0 yields -inf.
double inv_sigmoid(double x)
{
    const double inverse_odds = 1.0 / x - 1.0;
    return -log(inverse_odds);
}

// A (head, relation, tail) fact expressed with integer ids.
struct Triplet
{
    int h, t, r;

    // Strict weak ordering on (r, h, t), so triplets can key std::set / std::map.
    friend bool operator < (Triplet a, Triplet b)
    {
        if (a.r != b.r) return a.r < b.r;
        if (a.h != b.h) return a.h < b.h;
        return a.t < b.t;
    }
};
// One adjacency-list entry: the entity at the other end of an edge and the
// relation id labelling that edge.
struct Neighbor
{
    int e;  // neighbouring entity id
    int r;  // relation id of the connecting edge
};

// A mined logical rule.
//   type   - rule family: 'c' composition, 's' symmetric, 'i' inverse,
//            'b' subrelation (dispatched on in Precision / GetCandidate).
//   r1, r2 - body relation ids.
//   r      - head relation id; the search routines in this file set it to
//            -1 for every type except 'c'.
//   score  - precision of the rule, weight - learned weight,
//   grad   - gradient accumulated during one training pass.
// Only score, weight and grad are initialized by the constructor.
struct Rule
{
    int r1, r2, r;
    char type;
    double score, weight, grad;

    Rule() : score(0), weight(0), grad(0) {}

    // Orders rules by (type, r, r1, r2) so duplicates collapse in a std::set.
    friend bool operator < (Rule a, Rule b)
    {
        if (a.type != b.type) return a.type < b.type;
        if (a.r != b.r) return a.r < b.r;
        if (a.r1 != b.r1) return a.r1 < b.r1;
        return a.r2 < b.r2;
    }
};

// A triplet inferred by at least one rule.
//   h, r, t - entity / relation ids of the inferred triplet.
//   truth   - training target (set to 1 for observed triplets; may later be
//             replaced by an external score in ReadScore).
//   logit   - despite its name, Train() stores sigmoid(...) here, i.e. the
//             predicted probability.
//   ruleids - indices into vec_rules of the rules that produced it.
struct Candidate
{
    int h, r, t;
    double truth, logit;
    std::vector<int> ruleids;

    Candidate()
    {
        init();
    }

    // Resets the target, prediction and rule list (h, r, t are left as-is).
    void init()
    {
        truth = 0;
        logit = 0;
        ruleids.clear();
    }
};

// File paths, filled from the command line in main().
char triplet_file[MAX_STRING];
char score_file[MAX_STRING];
char output_rule_file[MAX_STRING];
char output_triplet_file[MAX_STRING];
char output_candidate_file[MAX_STRING];

// Data sizes (computed while running) and training configuration.
int entity_size = 0;
int relation_size = 0;
int triplet_size = 0;
int rule_size = 0;
int candidate_size = 0;
int iterations = 10;
double threshold = 0;
double learning_rate = 0.01;

// Bidirectional name <-> id dictionaries for entities and relations.
std::map<std::string, int> ent2id;
std::map<std::string, int> rel2id;
std::vector<std::string> id2ent;
std::vector<std::string> id2rel;

// Observed triplets, plus per-entity adjacency arrays (sized entity_size in
// ReadData): outgoing edges (tail, rel) and incoming edges (head, rel).
std::vector<Triplet> trips;
std::vector<Neighbor> *nbs_h2rt;
std::vector<Neighbor> *nbs_t2rh;

// Rules: de-duplicating set used during search, then the working vector.
std::set<Rule> set_rules;
std::vector<Rule> vec_rules;

// Fast membership test for observed triplets.
std::set<Triplet> valid;

// NOTE(review): not referenced anywhere in this file; Candidate defines no
// operator<, so inserting into this set would not compile.
std::set<Candidate> set_cands;

// Inferred triplet -> ids of rules (indices into vec_rules) that infer it.
std::map< Triplet, std::vector<int> > trip2rules;
// External per-triplet scores read by ReadScore.
std::map<Triplet, double> trip2score;
// Flattened candidates built from trip2rules.
std::vector<Candidate> vec_cands;

// Loads whitespace-separated "head relation tail" triplets from triplet_file.
// Entities and relations receive dense integer ids in first-seen order
// (head before tail). Every triplet is appended to `trips` and inserted into
// `valid`. The function then builds the per-entity adjacency arrays:
// nbs_h2rt[h] holds (tail, rel) pairs and nbs_t2rh[t] holds (head, rel) pairs.
// It exits with status 1 if the triplet file cannot be opened.
void ReadData()
{
    char s_head[MAX_STRING], s_tail[MAX_STRING], s_rel[MAX_STRING];
    char fmt[64];
    Triplet trip;
    Neighbor nb;
    FILE *fi;

    // Limit every %s conversion to MAX_STRING - 1 characters, so an overlong
    // token is split rather than overflowing the fixed-size buffers above.
    snprintf(fmt, sizeof(fmt), "%%%ds %%%ds %%%ds", MAX_STRING - 1, MAX_STRING - 1, MAX_STRING - 1);

    fi = fopen(triplet_file, "rb");
    if (fi == NULL)
    {
        fprintf(stderr, "ERROR: cannot open triplet file '%s'\n", triplet_file);
        exit(1);
    }

    while (fscanf(fi, fmt, s_head, s_rel, s_tail) == 3)
    {
        if (ent2id.count(s_head) == 0)
        {
            ent2id[s_head] = entity_size;
            id2ent.push_back(s_head);
            entity_size += 1;
        }

        if (ent2id.count(s_tail) == 0)
        {
            ent2id[s_tail] = entity_size;
            id2ent.push_back(s_tail);
            entity_size += 1;
        }

        if (rel2id.count(s_rel) == 0)
        {
            rel2id[s_rel] = relation_size;
            id2rel.push_back(s_rel);
            relation_size += 1;
        }

        trip.h = ent2id[s_head];
        trip.t = ent2id[s_tail];
        trip.r = rel2id[s_rel];
        trips.push_back(trip);
        valid.insert(trip);
    }
    fclose(fi);

    triplet_size = int(trips.size());

    nbs_h2rt = new std::vector<Neighbor> [entity_size];
    nbs_t2rh = new std::vector<Neighbor> [entity_size];
    for (int k = 0; k != triplet_size; k++)
    {
        nb.e = trips[k].t;
        nb.r = trips[k].r;
        nbs_h2rt[trips[k].h].push_back(nb);
        nb.e = trips[k].h;
        nb.r = trips[k].r;
        nbs_t2rh[trips[k].t].push_back(nb);
    }

    printf("Entity size: %d\n", entity_size);
    printf("Relation size: %d\n", relation_size);
    printf("Triplet size: %d\n", triplet_size);
}

void SearchComposition(int h, int t, int r)
{
    int len1, len2, mid, r1, r2;
    Rule rule;
    
    len1 = int(nbs_h2rt[h].size());
    for (int k = 0; k != len1; k++)
    {
        mid = nbs_h2rt[h][k].e;
        r1 = nbs_h2rt[h][k].r;
        
        len2 = int(nbs_h2rt[mid].size());
        for (int i = 0; i != len2; i++)
        {
            if (nbs_h2rt[mid][i].e != t) continue;
            
            r2 = nbs_h2rt[mid][i].r;
            rule.r1 = r1; rule.r2 = r2; rule.r = r; rule.type = 'c';
            
            set_rules.insert(rule);
        }
    }
}

void SearchSymmetric(int h, int t, int r)
{
    int len;
    Rule rule;
    
    len = int(nbs_h2rt[t].size());
    for (int k = 0; k != len; k++)
    {
        if (nbs_h2rt[t][k].r != r) continue;
        if (nbs_h2rt[t][k].e != h) continue;
        
        rule.r1 = r; rule.r2 = r; rule.r = -1; rule.type = 's';
        
        set_rules.insert(rule);
    }
}

void SearchInverse(int h, int t, int r)
{
    int len, invr;
    Rule rule;
    
    len = int(nbs_h2rt[t].size());
    for (int k = 0; k != len; k++)
    {
        if (nbs_h2rt[t][k].r == r) continue;
        if (nbs_h2rt[t][k].e != h) continue;
        
        invr = nbs_h2rt[t][k].r;
        
        rule.r1 = invr; rule.r2 = r; rule.r = -1; rule.type = 'i';
        
        set_rules.insert(rule);
    }
}

void SearchSubrelation(int h, int t, int r)
{
    int len, subr;
    Rule rule;
    
    len = int(nbs_h2rt[h].size());
    for (int k = 0; k != len; k++)
    {
        if (nbs_h2rt[h][k].e != t) continue;
        
        subr = nbs_h2rt[h][k].r;
        if (subr == r) continue;
        
        rule.r1 = subr; rule.r2 = r; rule.r = -1; rule.type = 'b';
        
        set_rules.insert(rule);
    }
}

void Search()
{
    for (int k = 0; k != triplet_size; k++)
    {
        SearchComposition(trips[k].h, trips[k].t, trips[k].r);
        SearchSymmetric(trips[k].h, trips[k].t, trips[k].r);
        SearchSubrelation(trips[k].h, trips[k].t, trips[k].r);
        SearchInverse(trips[k].h, trips[k].t, trips[k].r);
    }
    
    std::set<Rule>::iterator iter;
    for (iter = set_rules.begin(); iter != set_rules.end(); iter++)
    {
        vec_rules.push_back(*iter);
    }
    
    rule_size = int(set_rules.size());
    set_rules.clear();
    printf("Rule size: %d\n", rule_size);
}

// Precision of a composition rule r1(h, m) & r2(m, t) => r(h, t).
// This is the fraction of r1-then-r2 paths whose implied (h, r, t) is
// observed. Paths are counted individually, so an (h, t) pair reachable
// through several intermediates counts once per intermediate. The result is
// 0/0 (NaN) if the body never matches.
double PrecisionComposition(Rule rule)
{
    double hits = 0, total = 0;
    Triplet head;
    head.r = rule.r;

    for (size_t k = 0; k < trips.size(); ++k)
    {
        const Triplet &first = trips[k];
        if (first.r != rule.r1) continue;

        const std::vector<Neighbor> &hops = nbs_h2rt[first.t];
        for (size_t i = 0; i < hops.size(); ++i)
        {
            if (hops[i].r != rule.r2) continue;

            head.h = first.h;
            head.t = hops[i].e;
            total += 1;
            if (valid.count(head) != 0) hits += 1;
        }
    }

    return hits / total;
}

// Precision of a symmetric rule r1(h, t) => r2(t, h), where r1 == r2 for
// rules produced by SearchSymmetric. It is the fraction of r1 facts whose
// reversed r2 fact is observed (0/0 = NaN if r1 never occurs).
double PrecisionSymmetric(Rule rule)
{
    double hits = 0, total = 0;
    Triplet reversed;
    reversed.r = rule.r2;

    for (size_t k = 0; k < trips.size(); ++k)
    {
        if (trips[k].r != rule.r1) continue;

        reversed.h = trips[k].t;
        reversed.t = trips[k].h;
        total += 1;
        if (valid.count(reversed) != 0) hits += 1;
    }

    return hits / total;
}

// Precision of an inverse rule r1(h, t) => r2(t, h). It is the fraction of
// r1 facts whose reversed r2 fact is observed (0/0 = NaN if r1 never occurs).
double PrecisionInverse(Rule rule)
{
    double hits = 0, total = 0;
    Triplet reversed;
    reversed.r = rule.r2;

    for (size_t k = 0; k < trips.size(); ++k)
    {
        const Triplet &fact = trips[k];
        if (fact.r == rule.r1)
        {
            reversed.h = fact.t;
            reversed.t = fact.h;
            total += 1;
            if (valid.count(reversed) != 0) hits += 1;
        }
    }

    return hits / total;
}

// Precision of a subrelation rule r1(h, t) => r2(h, t). It is the fraction
// of r1 facts for which the same (h, t) pair also holds under r2
// (0/0 = NaN if r1 never occurs).
double PrecisionSubrelation(Rule rule)
{
    double hits = 0, total = 0;
    Triplet implied;
    implied.r = rule.r2;

    for (size_t k = 0; k < trips.size(); ++k)
    {
        if (trips[k].r != rule.r1) continue;

        implied.h = trips[k].h;
        implied.t = trips[k].t;
        total += 1;
        if (valid.count(implied) != 0) hits += 1;
    }

    return hits / total;
}

void Precision()
{
    for (int k = 0; k != rule_size; k++)
    {
        if (k % 1000 == 0) printf("%d\n", k);
        
        if (vec_rules[k].type == 'c') vec_rules[k].score = PrecisionComposition(vec_rules[k]);
        if (vec_rules[k].type == 's') vec_rules[k].score = PrecisionSymmetric(vec_rules[k]);
        if (vec_rules[k].type == 'i') vec_rules[k].score = PrecisionInverse(vec_rules[k]);
        if (vec_rules[k].type == 'b') vec_rules[k].score = PrecisionSubrelation(vec_rules[k]);
    }
    
    std::vector<Rule> new_rules;
    for (int k = 0; k != rule_size; k++)
    {
        if (vec_rules[k].score < threshold) continue;
        new_rules.push_back(vec_rules[k]);
    }
    
    vec_rules = new_rules;
    rule_size = int(vec_rules.size());
    new_rules.clear();
    
    printf("New Rule size: %d\n", rule_size);
}

void CandidateComposition(int ruleid)
{
    int len, h, mid, t;
    int r1, r2, r;
    Triplet trip;
    
    r1 = vec_rules[ruleid].r1;
    r2 = vec_rules[ruleid].r2;
    r = vec_rules[ruleid].r;
    
    len = int(trips.size());
    for (int k = 0; k != len; k++)
    {
        if (trips[k].r != r1) continue;
        
        h = trips[k].h;
        mid = trips[k].t;
        
        for (int i = 0; i != int(nbs_h2rt[mid].size()); i++)
        {
            if (nbs_h2rt[mid][i].r != r2) continue;
            
            t = nbs_h2rt[mid][i].e;
            
            trip.h = h; trip.r = r; trip.t = t;
            
            trip2rules[trip].push_back(ruleid);
        }
    }
}

void CandidateSymmetric(int ruleid)
{
    int h, r1, r2, t, len;
    Triplet trip;
    
    r1 = vec_rules[ruleid].r1;
    r2 = vec_rules[ruleid].r2;
    
    len = int(trips.size());
    for (int k = 0; k != len; k++)
    {
        if (trips[k].r != r1) continue;
        
        h = trips[k].h;
        t = trips[k].t;
        
        trip.h = t; trip.t = h; trip.r = r2;
        
        trip2rules[trip].push_back(ruleid);
    }
}

void CandidateInverse(int ruleid)
{
    int h, r1, r2, t, len;
    Triplet trip;
    
    r1 = vec_rules[ruleid].r1;
    r2 = vec_rules[ruleid].r2;
    
    len = int(trips.size());
    for (int k = 0; k != len; k++)
    {
        if (trips[k].r != r1) continue;
        
        h = trips[k].h;
        t = trips[k].t;
        
        trip.h = t; trip.t = h; trip.r = r2;
        
        trip2rules[trip].push_back(ruleid);
    }
}

void CandidateSubrelation(int ruleid)
{
    int h, r1, r2, t, len;
    Triplet trip;
    
    r1 = vec_rules[ruleid].r1;
    r2 = vec_rules[ruleid].r2;
    
    len = int(trips.size());
    for (int k = 0; k != len; k++)
    {
        if (trips[k].r != r1) continue;
        
        h = trips[k].h;
        t = trips[k].t;
        
        trip.h = h; trip.t = t; trip.r = r2;
        
        trip2rules[trip].push_back(ruleid);
    }
}

void GetCandidate()
{
    for (int k = 0; k != rule_size; k++)
    {
        if (k % 1000 == 0) printf("%d\n", k);
        
        if (vec_rules[k].type == 'c') CandidateComposition(k);
        if (vec_rules[k].type == 's') CandidateSymmetric(k);
        if (vec_rules[k].type == 'i') CandidateInverse(k);
        if (vec_rules[k].type == 'b') CandidateSubrelation(k);
    }
    
    std::map< Triplet, std::vector<int> >::iterator iter;
    Triplet trip;
    Candidate cand;
    for (iter = trip2rules.begin(); iter != trip2rules.end(); iter++)
    {
        cand.init();
        
        trip = iter->first;
        if (valid.count(trip) != 0) cand.truth = 1;
        else cand.truth = 0;
        cand.h = trip.h;
        cand.t = trip.t;
        cand.r = trip.r;
        cand.ruleids = iter->second;
        vec_cands.push_back(cand);
    }
    
    candidate_size = int(vec_cands.size());
    printf("Candidate size: %d\n", candidate_size);
}

void ReadScore()
{
	if (score_file[0] == 0) return;

	char s_head[MAX_STRING], s_tail[MAX_STRING], s_rel[MAX_STRING];
    double score;
    Triplet trip;

	FILE *fi = fopen(score_file, "rb");
	while (1)
	{
		if (fscanf(fi, "%s %s %s %lf", s_head, s_rel, s_tail, &score) != 4) break;

		if (ent2id.count(s_head) == 0) continue;
		if (ent2id.count(s_tail) == 0) continue;
		if (rel2id.count(s_rel) == 0) continue;

		trip.h = ent2id[s_head];
		trip.t = ent2id[s_tail];
		trip.r = rel2id[s_rel];

		trip2score[trip] = score;
	}
	fclose(fi);

	int cn = 0;
	for (int k = 0; k != candidate_size; k++)
	{
		if (vec_cands[k].truth != 0) continue;

		trip.h = vec_cands[k].h;
		trip.t = vec_cands[k].t;
		trip.r = vec_cands[k].r;

		if (trip2score.count(trip) == 0) continue;
        //if (trip2score[trip] < 0.9) vec_cands[k].truth = 0;
        //else vec_cands[k].truth = 1;
		vec_cands[k].truth = trip2score[trip];
		cn += 1;
	}

	printf("%d\n", cn);
}

// Initializes each rule's weight to the logit (inverse sigmoid) of its
// precision score, so that sigmoid(weight) starts at the rule's precision.
// The score is first clamped to [eps, 1 - eps]. Otherwise a perfectly precise
// rule (score == 1) gets weight +inf, and score == 0 gets -inf. Infinite
// weights are never brought back to finite values by the additive gradient
// updates in Train, and +inf and -inf combined in one candidate give NaN.
void InitWeight()
{
    const double eps = 1e-6;
    for (int k = 0; k != rule_size; k++)
    {
        double score = vec_rules[k].score;
        if (score < eps) score = eps;
        if (score > 1.0 - eps) score = 1.0 - eps;
        vec_rules[k].weight = inv_sigmoid(score);
    }
}

double Train(double lr)
{
	double error = 0;

    for (int k = 0; k != rule_size; k++) vec_rules[k].grad = 0;
    
    for (int k = 0; k != candidate_size; k++)
    {
        int len = int(vec_cands[k].ruleids.size());
        
        vec_cands[k].logit = 0;
        for (int i = 0; i != len; i++)
        {
            int ruleid = vec_cands[k].ruleids[i];
            vec_cands[k].logit += vec_rules[ruleid].weight / len;
        }
        
        vec_cands[k].logit = sigmoid(vec_cands[k].logit);
        for (int i = 0; i != len; i++)
        {
            int ruleid = vec_cands[k].ruleids[i];
            vec_rules[ruleid].grad += (vec_cands[k].truth - vec_cands[k].logit) / len;
        }

        error += (vec_cands[k].truth - vec_cands[k].logit) * (vec_cands[k].truth - vec_cands[k].logit);
    }
    
    for (int k = 0; k != rule_size; k++) vec_rules[k].weight += lr * vec_rules[k].grad;

    return sqrt(error / candidate_size);
}

void OutputRule()
{
    if (output_rule_file[0] == 0) return;

    FILE *fo = fopen(output_rule_file, "wb");
    for (int k = 0; k != rule_size; k++)
    {
        char type = vec_rules[k].type;
        int r1 = vec_rules[k].r1;
        int r2 = vec_rules[k].r2;
        int r = vec_rules[k].r;
        double weight = vec_rules[k].weight;
        std::string none = "None";
        
        if (r != -1)
        {
            fprintf(fo, "%c\t%s\t%s\t%s\t%lf\n", type, id2rel[r1].c_str(), id2rel[r2].c_str(), id2rel[r].c_str(), weight);
        }
        else
        {
            fprintf(fo, "%c\t%s\t%s\t%s\t%lf\n", type, id2rel[r1].c_str(), id2rel[r2].c_str(), none.c_str(), weight);
        }
    }
    fclose(fo);
}

void OutputTriplet()
{
    if (output_triplet_file[0] == 0) return;

    FILE *fo = fopen(output_triplet_file, "wb");
    for (int k = 0; k != candidate_size; k++)
    {
        int h = vec_cands[k].h;
        int t = vec_cands[k].t;
        int r = vec_cands[k].r;
        double prob = vec_cands[k].logit;
        
        fprintf(fo, "%s\t%s\t%s\t%lf\n", id2ent[h].c_str(), id2rel[r].c_str(), id2ent[t].c_str(), prob);
    }
    fclose(fo);
}

void OutputCandidate()
{
    if (output_candidate_file[0] == 0) return;

    FILE *fo = fopen(output_candidate_file, "wb");
    for (int k = 0; k != candidate_size; k++)
    {
        int h = vec_cands[k].h;
        int t = vec_cands[k].t;
        int r = vec_cands[k].r;
        
        fprintf(fo, "%s\t%s\t%s\n", id2ent[h].c_str(), id2rel[r].c_str(), id2ent[t].c_str());
    }
    fclose(fo);
}

// Runs the whole pipeline: load the graph, mine and score rules, build
// candidates, optionally attach external scores, train the rule weights for
// `iterations` epochs (printing the RMSE of each), and write the requested
// output files.
void TrainModel()
{
    // Rule mining and scoring.
    ReadData();
    Search();
    Precision();

    // Candidate triplets and their labels.
    GetCandidate();
    ReadScore();

    // Weight learning.
    InitWeight();
    for (int epoch = 0; epoch != iterations; ++epoch)
    {
        const double rmse = Train(learning_rate);
        printf("Iteration: %d %lf\n", epoch, rmse);
    }

    // Results.
    OutputRule();
    OutputTriplet();
    OutputCandidate();
}

// Returns the index in argv (searching from 1) of the first argument equal
// to `str`, or -1 if it is absent. A flag found in the last position has no
// value following it; this prints an error and exits with status 1.
// `str` is taken as const char* so string literals can be passed without a
// cast. Callers that pass char* still compile unchanged.
int ArgPos(const char *str, int argc, char **argv) {
    for (int a = 1; a < argc; a++) {
        if (strcmp(str, argv[a]) != 0) continue;
        if (a == argc - 1) {
            printf("Argument missing for %s\n", str);
            exit(1);
        }
        return a;
    }
    return -1;
}

int main(int argc, char **argv) {
    int i;
    if (argc == 1) {
        return 0;
    }
    score_file[0] = 0;
    output_rule_file[0] = 0;
    output_triplet_file[0] = 0;
    output_candidate_file[0] = 0;
    if ((i = ArgPos((char *)"-triplet", argc, argv)) > 0) strcpy(triplet_file, argv[i + 1]);
    if ((i = ArgPos((char *)"-score", argc, argv)) > 0) strcpy(score_file, argv[i + 1]);
    if ((i = ArgPos((char *)"-outrule", argc, argv)) > 0) strcpy(output_rule_file, argv[i + 1]);
    if ((i = ArgPos((char *)"-outtrip", argc, argv)) > 0) strcpy(output_triplet_file, argv[i + 1]);
    if ((i = ArgPos((char *)"-outcand", argc, argv)) > 0) strcpy(output_candidate_file, argv[i + 1]);
    if ((i = ArgPos((char *)"-threshold", argc, argv)) > 0) threshold = atof(argv[i + 1]);
    if ((i = ArgPos((char *)"-iterations", argc, argv)) > 0) iterations = atoi(argv[i + 1]);
    if ((i = ArgPos((char *)"-lr", argc, argv)) > 0) learning_rate = atof(argv[i + 1]);
    TrainModel();
    return 0;
}
