from argparse import ArgumentParser
import os

import pandas as pd
from tqdm.auto import tqdm

# Location of each input SAS file, keyed first by data year and then by
# claim type (looked up as PATHS[year][claim_type] when scanning).
# Real paths are redacted -- the empty strings are placeholders to fill in.
PATHS = {
    2019: dict(
        dme="",
        hha="",
        medpar="",
        op="",
        ptb="",
        denom="",
    ),
    2018: dict(
        dme="",
        hha="",
        medpar="",
        op="",
        ptb="",
        denom="",
    ),
}

def get_args(argv=None):
    """Parse the command-line options for the claims-subsetting script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``overwrite``,
        ``chunksize``, ``n_chunks``, ``filter_suffix``, ``year`` and
        ``include_claims``.
    """
    psr = ArgumentParser()
    psr.add_argument("--overwrite", action='store_true',
                     help="replace existing output CSV files")
    psr.add_argument("--chunksize", type=int, default=100000,
                     help="rows per chunk read from each SAS file")
    psr.add_argument("--n_chunks", type=int, default=None,
                     help="stop after this many chunks (default: read all)")
    psr.add_argument("--filter-suffix", type=str, default="22",
                     help="keep rows whose beneficiary ID ends with this")
    # type=int is required: argparse checks `choices` against the converted
    # value, so without it "--year 2019" is the string "2019" and is rejected.
    psr.add_argument("--year", type=int, choices=[2018, 2019], default=2019,
                     help="data year to process")
    psr.add_argument(
        "--include-claims",
        type=str,
        nargs="+",
        default=["dme", "hha", "op", "medpar", "ptb", "denom"],
        help="claim types to scan",
    )
    return psr.parse_args(argv)

def _filter_and_decode(chunk, suffix):
    """Return the rows of ``chunk`` whose beneficiary ID ends with ``suffix``.

    The ID column is ``BENE_ID`` if present, otherwise ``bene_id``. Every
    object-dtype column of the returned copy is UTF-8 decoded.
    """
    bene_col = "BENE_ID" if "BENE_ID" in chunk.columns else "bene_id"
    # read_sas is called without ``encoding``, so character columns hold raw
    # bytes and must be decoded before the string comparison. ``na=False``
    # drops rows with a missing ID instead of putting NaN into the mask.
    keep = chunk[bene_col].str.decode("utf-8").str.endswith(suffix, na=False)
    # Explicit copy: columns are reassigned below, which on a view would
    # trigger SettingWithCopyWarning.
    subset = chunk[keep].copy()
    for col in subset.select_dtypes([object]).columns:
        subset[col] = subset[col].str.decode("utf-8")
    return subset

def scan_and_save(claim_type, year, chunksize=10000, suffix="22", n_chunks=None, overwrite=False):
    """Stream one SAS claims file and write the matching rows to CSV.

    Reads ``PATHS[year][claim_type]`` in chunks of ``chunksize`` rows and
    keeps rows whose beneficiary ID ends with ``suffix``. Byte-string columns
    are decoded to text, and the concatenated result is written to
    ``./raw_subset/{year}_{claim_type}.csv``. The directory is created if
    needed.

    Args:
        claim_type: Key into ``PATHS[year]`` (e.g. ``"dme"``).
        year: Key into ``PATHS``.
        chunksize: Rows per chunk passed to ``pd.read_sas``.
        suffix: Required ending of the beneficiary ID.
        n_chunks: Stop after this many chunks; ``None`` (or a negative value)
            reads the whole file.
        overwrite: Replace an existing output file instead of failing.

    Raises:
        ValueError: If the output CSV already exists and ``overwrite`` is false.
        KeyError: If ``year`` or ``claim_type`` is not in ``PATHS``.
    """
    from itertools import islice

    path = f"./raw_subset/{year}_{claim_type}.csv"
    if os.path.isfile(path) and not overwrite:
        raise ValueError(f"File exists at {path} and overwrite flag is false.")

    # A negative n_chunks can never equal the chunk count, so treat it like
    # None (read everything); islice(it, None) means "no limit".
    limit = n_chunks if n_chunks is not None and n_chunks >= 0 else None
    subs = []
    chunks_seen = 0
    reader = pd.read_sas(PATHS[year][claim_type], chunksize=chunksize)
    try:
        with tqdm(desc=f"Scanning {claim_type} claims", unit="ln",
                  total=reader.row_count) as pbar:
            for chunk in islice(reader, limit):
                subs.append(_filter_and_decode(chunk, suffix))
                # The last chunk is usually shorter than chunksize.
                pbar.update(len(chunk))
                chunks_seen += 1
    finally:
        reader.close()

    if n_chunks is not None and chunks_seen == n_chunks:
        print("Read", n_chunks, "chunks -- exiting")

    if not subs:
        # pd.concat([]) raises; report instead of crashing with no output.
        print(f"No chunks read for {claim_type}; nothing written to {path}")
        return
    os.makedirs(os.path.dirname(path), exist_ok=True)
    pd.concat(subs).to_csv(path)

if __name__ == '__main__':
    # Scan every requested claim type for the chosen year, one output CSV each.
    args = get_args()
    available = PATHS[args.year]
    for claim_type in args.include_claims:
        if claim_type not in available:
            # Report typos in --include-claims rather than ignoring them.
            print(f"Unknown claim type {claim_type!r} for {args.year}; skipping")
            continue
        scan_and_save(
            claim_type,
            args.year,
            chunksize=args.chunksize,
            suffix=args.filter_suffix,
            n_chunks=args.n_chunks,
            overwrite=args.overwrite,
        )

    
