#!/usr/bin/env python

import argparse
import random

# This class splits users into training, test and validation sets.

class DataCutter(object):
    """Split a comma-separated record file into train/test/validation files.

    Records are grouped by their first field (the user id), so all records
    of one user always end up in the same output file.

    Args:
        inp: path of the input file. A valid line has exactly five
            comma-separated fields, the first one being the user id.
        train: path of the output file receiving the training set.
        test: path of the output file receiving the test set.
        validation: path of the output file receiving the validation set.
        number: how many users go to the test set and, separately, to the
            validation set. All remaining users form the training set.
    """

    # A well-formed record has exactly this many comma-separated fields.
    _FIELD_COUNT = 5

    def __init__(self, inp, train, test, validation, number):
        self._input = inp
        self._train = train
        self._test = test
        self._validation = validation
        self._number = number

    def cut(self):
        """Read the input, shuffle the users and write the three output files.

        The user order is randomised with ``random.shuffle``; the first
        ``number`` users become the test set, the next ``number`` users the
        validation set and the rest the training set. Existing output files
        are overwritten.
        """
        user_ids, user_behav = self._read_records()

        random.shuffle(user_ids)
        count = self._number
        test_user_ids = user_ids[:count]
        validation_user_ids = user_ids[count:2 * count]
        train_user_ids = user_ids[2 * count:]

        # write dataset
        self._write_records(self._train, train_user_ids, user_behav)
        self._write_records(self._test, test_user_ids, user_behav)
        self._write_records(self._validation, validation_user_ids, user_behav)

    def _read_records(self):
        """Return (user ids in first-seen order, {user id: [record lines]})."""
        user_behav = dict()
        user_ids = list()
        with open(self._input) as f:
            for line in f:
                arr = line.strip().split(',')
                if len(arr) != self._FIELD_COUNT:
                    # Skip a malformed line instead of aborting, so that one
                    # bad or blank line does not drop the rest of the file.
                    continue

                if not line.endswith('\n'):
                    # The last line of a file may lack a newline; add it so
                    # the record cannot merge with the one written after it.
                    line += '\n'

                uid = arr[0]
                if uid not in user_behav:
                    user_ids.append(uid)
                    user_behav[uid] = list()

                user_behav[uid].append(line)
        return user_ids, user_behav

    @staticmethod
    def _write_records(path, user_ids, user_behav):
        """Write every record of the given users to ``path``, user by user."""
        # Text mode: the records are str objects read from a text-mode file,
        # and writing str to a 'wb' file raises TypeError on Python 3.
        with open(path, 'w') as f:
            for uid in user_ids:
                f.writelines(user_behav[uid])


if __name__ == '__main__':
    # Command-line entry point: all five options are mandatory.
    _PARSER = argparse.ArgumentParser(
        description="DataCutter: cut data to train, test and validation set"
    )

    # Input file.
    _PARSER.add_argument(
        "--input",
        required=True,
        help="input filename",
    )

    # Output files, one per split.
    _PARSER.add_argument(
        "--train",
        required=True,
        help="filename of output train set",
    )
    _PARSER.add_argument(
        "--test",
        required=True,
        help="filename of output test set",
    )
    _PARSER.add_argument(
        "--validation",
        required=True,
        help="filename of output validation set",
    )

    # Size (in users) of the test set and of the validation set.
    _PARSER.add_argument(
        "--number",
        required=True,
        type=int,
        help="number of users for test and validation set",
    )

    _ARGS = _PARSER.parse_args()

    # Arguments are passed positionally, in the constructor's declared order.
    DataCutter(
        _ARGS.input,
        _ARGS.train,
        _ARGS.test,
        _ARGS.validation,
        _ARGS.number,
    ).cut()
