import sys
import os

# Command line: RULE_PRED_FILE KGE_PRED_FILE OUTPUT_FILE WEIGHT
# Fail with a usage message rather than a bare IndexError on missing args.
if len(sys.argv) < 5:
	sys.exit('usage: {} RULE_PRED_FILE KGE_PRED_FILE OUTPUT_FILE WEIGHT'.format(sys.argv[0]))

rule_pred_file = sys.argv[1]
kge_pred_file = sys.argv[2]
output_file = sys.argv[3]
# Multiplier applied to the rule probability before it is added to the KGE score.
weight = float(sys.argv[4])

# Rule predictions: one tab-separated record per line, "h\tr\tt\tp" (any
# further columns are ignored). A later duplicate triple overwrites an
# earlier one.
hrt2p = dict()
with open(rule_pred_file, 'r') as fi:
	for line in fi:
		# Skip blank lines; unpacking them would raise ValueError.
		if not line.strip():
			continue
		h, r, t, p = line.strip().split('\t')[0:4]
		hrt2p[(h, r, t)] = float(p)

def get_rule_prob(h, r, t):
	"""Return the rule-based score of the triple (h, r, t).

	Scores come from the module-level ``hrt2p`` table.

	- For a self-loop (h == t), the stored probability is returned only if
	  it is not below 0.5. A missing entry counts as 0, and anything below
	  0.5 yields -100 so the candidate is pushed far down the ranking.
	- For any other triple, the stored probability is returned, or 0.5
	  when the rule file has no entry for it.
	"""
	key = (h, r, t)
	if h != t:
		return hrt2p.get(key, 0.5)
	score = hrt2p.get(key, 0)
	return -100 if score < 0.5 else score

def _parse_preds(tokens):
	"""Turn KGE candidate tokens of the form ``entity:score`` into pairs.

	Returns a list of mutable ``[entity, score]`` lists, so the caller can
	add the rule-based bonus in place. The split is on the LAST ':', which
	keeps entity ids that themselves contain ':' intact.
	"""
	pairs = []
	for token in tokens:
		entity, score = token.rsplit(':', 1)
		pairs.append([entity, float(score)])
	return pairs


def _rerank(h, r, t, mode, tokens, prev_ranking):
	"""Return the 1-based rank of the true entity after rule re-scoring.

	Each candidate's KGE score is increased by ``weight`` times the rule
	score of the triple it would complete. That triple is
	``(candidate, r, t)`` when predicting the head (mode 'h') and
	``(h, r, candidate)`` when predicting the tail (mode 't').

	Candidates are then sorted by combined score, descending; the sort is
	stable, so ties keep their input order. If the true entity is not among
	the candidates, ``prev_ranking`` (the KGE model's own rank) is returned.

	Raises:
		ValueError: if ``mode`` is neither 'h' nor 't'.
	"""
	candidates = _parse_preds(tokens)
	if mode == 'h':
		target = h
		for cand in candidates:
			cand[1] += get_rule_prob(cand[0], r, t) * weight
	elif mode == 't':
		target = t
		for cand in candidates:
			cand[1] += get_rule_prob(h, r, cand[0]) * weight
	else:
		# Fail loudly: without a valid mode there is no ranking for this
		# query, and scoring it with a stale value would skew every metric.
		raise ValueError('unknown prediction mode {!r} for triple ({}, {}, {})'.format(mode, h, r, t))

	candidates.sort(key=lambda cand: cand[1], reverse=True)
	for rank, (entity, _score) in enumerate(candidates, start=1):
		if entity == target:
			return rank
	return prev_ranking


# Running totals of the ranking metrics; averaged over `cn` queries below.
hit1 = 0
hit3 = 0
hit10 = 0
mr = 0
mrr = 0
cn = 0

# The KGE prediction file holds one record per query, as two lines:
#   1. "h r t mode prev_ranking" (whitespace-separated; extra fields ignored)
#   2. whitespace-separated "entity:score" candidate tokens
with open(kge_pred_file, 'r') as fi:
	while True:
		truth = fi.readline()
		preds = fi.readline()

		# Stop at EOF, or at a trailing record that has no candidate line.
		if (not truth) or (not preds):
			break

		h, r, t, mode, prev_ranking = truth.strip().split()[0:5]
		ranking = _rerank(h, r, t, mode, preds.strip().split(), int(prev_ranking))

		if ranking <= 1:
			hit1 += 1
		if ranking <= 3:
			hit3 += 1
		if ranking <= 10:
			hit10 += 1
		mr += ranking
		mrr += 1.0 / ranking
		cn += 1

# An empty prediction file would otherwise die with ZeroDivisionError.
if cn == 0:
	sys.exit('No queries found in {}'.format(kge_pred_file))

hit1 /= cn
hit3 /= cn
hit10 /= cn
mr /= cn
mrr /= cn

print('Hit@1: ', hit1)
print('Hit@3: ', hit3)
print('Hit@10: ', hit10)
print('MR: ', mr)
print('MRR: ', mrr)

# Persist the averaged metrics to output_file, one "Name: value" line each,
# in the same order they were printed.
with open(output_file, 'w') as fo:
	for label, value in (
		('Hit@1', hit1),
		('Hit@3', hit3),
		('Hit@10', hit10),
		('MR', mr),
		('MRR', mrr),
	):
		fo.write('{}: {}\n'.format(label, value))
