import numpy as np
import pandas as pd

from numpy.random import default_rng

import os

def process_csv(csv_path, output_dir, target_name):
    """Convert a CSV file into saved feature and label arrays.

    The column named ``target_name`` is split off as the label and encoded
    as integer category codes. The remaining columns are passed through
    ``filter_binary`` and converted to a NumPy array. Both arrays are
    written under ``output_dir`` (created if missing) with the base names
    ``X_data`` and ``y_data``.

    Args:
        csv_path: Path of the CSV file to read.
        output_dir: Directory that receives the two saved arrays.
        target_name: Name of the label column in the CSV.
    """
    frame = pd.read_csv(csv_path)

    # Separate the label column from the feature columns.
    labels = frame[target_name]
    features = filter_binary(frame.drop(columns=[target_name]))

    # Map each distinct label value to an integer code.
    label_codes = labels.astype('category').cat.codes

    X = features.to_numpy()
    y = label_codes.to_numpy()

    os.makedirs(output_dir, exist_ok=True)

    for stem, array in (('X_data', X), ('y_data', y)):
        np.save(f'{output_dir}/{stem}', array)

def filter_binary(df):
    """Return a new frame with only the binary columns, integer-encoded.

    A column is kept when it has at most two distinct values, as counted
    by ``Series.unique()`` (so a missing value counts as its own distinct
    value). Kept columns are replaced by their category codes; all other
    columns are dropped. Column order and the row index are preserved.

    The input frame is never modified: the result is built on a copy, so
    the caller's original column values stay intact.

    Args:
        df: Input ``pandas.DataFrame``.

    Returns:
        A new ``pandas.DataFrame`` containing the encoded binary columns.
    """
    # Decide which columns survive in one pass instead of dropping them
    # one by one (each drop would copy the whole frame).
    keep = [col for col in df.columns if len(df[col].unique()) <= 2]

    # Explicit copy so the assignments below cannot reach the caller's data.
    encoded = df[keep].copy()
    for col in keep:
        encoded[col] = encoded[col].astype('category').cat.codes

    return encoded

def sample_subset(X, y, n, seed, return_inds=False, replace=True):
    """Draw a seeded random selection of ``n`` rows from ``X`` and ``y``.

    Args:
        X: Array whose first axis enumerates the samples.
        y: Array indexed along its first axis in step with ``X``.
        n: Number of rows to draw.
        seed: Seed handed to ``numpy.random.default_rng``.
        return_inds: When true, return the drawn indices instead of the
            selected rows.
        replace: Whether rows may be drawn more than once.

    Returns:
        The index array if ``return_inds`` is true, otherwise the tuple
        ``(X[inds], y[inds])``.
    """
    generator = default_rng(seed)
    chosen = generator.choice(X.shape[0], size=n, replace=replace)
    return chosen if return_inds else (X[chosen], y[chosen])

def apply_noise(X, y, p, n, seed, exact, replace=True):
    """Flip binary labels at rate ``p`` and subsample the noisy data.

    Labels are flipped via ``(y + noise) % 2``, so ``y`` is expected to
    hold 0/1 values.

    Args:
        X: Feature array; returned unchanged and also subsampled.
        y: Label array of 0/1 values.
        p: Noise rate in ``[0, 1]``.
        n: Number of rows to draw for the subsample.
        seed: Seed for the noise generator. The subsample uses
            ``seed * 3``; ``None`` is passed through unchanged.
        exact: If true, exactly ``int(len(y) * p)`` labels are flipped,
            chosen without replacement. If false, each label is flipped
            independently with probability ``p``.
        replace: Whether the subsample is drawn with replacement.

    Returns:
        Tuple ``(X, y_p, X_s, y_s)``: the original features, the noisy
        labels, and the subsampled features and noisy labels.
    """
    rng = default_rng(seed)
    n_labels = y.shape[0]

    if exact:
        flipped = rng.choice(n_labels, size=int(n_labels * p), replace=False)
        # Integer mask: a float mask would turn the labels into floats,
        # unlike the non-exact branch, which yields integers.
        noise = np.zeros(y.shape, dtype=int)
        noise[flipped] = 1
    else:
        noise = rng.choice([0, 1], size=y.shape, p=[1 - p, p])

    y_p = (y + noise) % 2

    # Derive a different seed for the subsample; None cannot be multiplied,
    # so pass it through and let the generator seed itself.
    subset_seed = None if seed is None else seed * 3
    X_s, y_s = sample_subset(X, y_p, n, subset_seed, replace=replace)

    return X, y_p, X_s, y_s