import librosa
import glob
import os
import torch
import laion_clap
import time

# CLAP with feature fusion enabled; the checkpoint below is the matching
# "fullset fusion" weights file, loaded from the working directory.
model = laion_clap.CLAP_Module(enable_fusion=True)

ckpt_path = './laion_clap_fullset_fusion.pt'
model.load_ckpt(ckpt_path)
print("loaded model successfully!")

# Move the model to the default CUDA device (requires a GPU).
model.cuda()

# Gather every .wav clip and fix a deterministic (lexicographic) order, so the
# row order of the saved embeddings follows this file order.
audio_dir = '../data/AVE_Dataset/audios'
audio_files = sorted(glob.glob(os.path.join(audio_dir, '*.wav')))
print(f'number of audio file: {len(audio_files)}')
# Debug aid: uncomment to process only a few clips.
# audio_files = audio_files[0:10]

batch_size = 96
# Count of complete batch_size-sized batches (the remainder is not counted).
split_num = len(audio_files)//batch_size

# Embed all clips batch by batch and save the stacked embeddings.
#
# Iterating over start offsets (instead of split_num + 1 slices) guarantees
# every batch is non-empty: the old scheme produced an empty final slice
# whenever len(audio_files) was an exact multiple of batch_size.
if not audio_files:
    # torch.cat([]) would otherwise fail with an obscure error below.
    raise RuntimeError('no audio files found; nothing to embed')

audio_emb_list = []
with torch.no_grad():
    for start in range(0, len(audio_files), batch_size):
        audio_files_split = audio_files[start:start + batch_size]
        audio_emb = model.get_audio_embedding_from_filelist(x = audio_files_split)
        # as_tensor accepts a numpy array or an existing tensor without the
        # copy-and-warn behaviour torch.tensor() has for tensor inputs.
        audio_emb_list.append(torch.as_tensor(audio_emb))
        # Report files actually processed so far (last batch may be short).
        done = start + len(audio_files_split)
        print(f"{done}/{len(audio_files)} shape:{audio_emb.shape}", time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))

    # Rows follow the order of audio_files.
    audio_embs = torch.cat(audio_emb_list).cpu()
    # torch.save does not create missing parent directories.
    os.makedirs('./embedding', exist_ok=True)
    torch.save(audio_embs, './embedding/CLAP_AVE_Audio.pt')
    print(audio_embs.shape)