% ---- Problem dimensions ----
S = 1000;                  % number of rows (states) of the target matrix
A = 1000;                  % number of columns (actions) of the target matrix
d = 5;                     % rank used for every truncated SVD in this script
deff = 10;                 % number of anchor rows / anchor columns per method

% ---- Experiment settings ----
std_noise_reward = 1e-2;   % standard deviation of the additive Gaussian observation noise
N_steps = 20;              % number of sample budgets evaluated
N_experiments = 30;        % number of independent repetitions

% Sample budgets: 2*(S+A)*deff scaled by N_steps factors spread evenly over [1,50].
N_samples_arr = ceil(2*(S+A)*deff*linspace(1,50,N_steps));

% ---- Result buffers: one row per experiment, one column per budget ----
% "_inf" stores the max-abs-entry error, "_fro" the Frobenius-norm error.
[uniform_list_inf,          uniform_list_fro]          = deal(zeros(N_experiments,N_steps));
[leveraged_list_inf,        leveraged_list_fro]        = deal(zeros(N_experiments,N_steps));
[leveraged_oracle_list_inf, leveraged_oracle_list_fro] = deal(zeros(N_experiments,N_steps));
[svd_list_inf,              svd_list_fro]              = deal(zeros(N_experiments,N_steps));


% Main Monte-Carlo loop. Each experiment estimates Qstar from noisy entry
% observations with four methods (uniform / oracle / estimated-leverage
% anchors, and plain truncated SVD) at every sample budget in N_samples_arr,
% and stores the max-abs-entry and Frobenius errors in the *_list_* buffers.
for i_experiment = 1:N_experiments
    disp(i_experiment)

    % A fresh ground-truth matrix is drawn on experiments 1, 6, 11, ...;
    % the experiments in between reuse the previous Qstar, Ur and Vr.
    if mod(i_experiment,5)==1
        % Background entries in [-0.09, 0.90] plus sparse (~1% of cells) spikes of size up to 20.
        rlow = (randi(100,[S,A])-10)/100 + 20*rand([S,A]).*(rand([S,A])>0.99);
        Sigmar = eye([d,d]);
        [Utmp, ~, Vtmp] = svds(rlow,d);
        Ur = Utmp(:,1:d);
        Vr = Vtmp(:,1:d);
        % Rank-d target with unit singular values and the top-d singular vectors of rlow.
        Qstar = Ur*Sigmar*Vr';
    end

    % Oracle anchors: the deff rows / columns whose true singular-vector
    % rows have the largest Euclidean norm.
    ellU = sqrt(sum(Ur.^2,2))';
    ellW = sqrt(sum(Vr.^2,2))';

    probsU = ellU/sum(ellU);
    [~,anchor_states_oracle] = maxk(probsU,deff);

    probsW = ellW/sum(ellW);
    [~,anchor_actions_oracle] = maxk(probsW,deff);

    error_inf_uniform = zeros(1,N_steps);
    error_fro_uniform = zeros(1,N_steps);
    error_inf_leveraged = zeros(1,N_steps);
    error_fro_leveraged = zeros(1,N_steps);
    error_inf_oracle = zeros(1,N_steps);
    error_fro_oracle = zeros(1,N_steps);
    error_inf_svd = zeros(1,N_steps);
    error_fro_svd = zeros(1,N_steps);

    %% Anchor-based estimate with uniformly drawn anchors
    for curr_epoch = 1:N_steps
        % Number of noisy observations averaged per anchor-row / anchor-column entry.
        N_samples = ceil(N_samples_arr(curr_epoch)/((S+A)*deff));

        anchor_states = randperm(S,deff);
        anchor_actions = randperm(A,deff);

        % Observe the anchor rows, then the anchor columns (cells at an
        % anchor row/column intersection keep the column-pass observation).
        Q = zeros([S,A]);
        for i_s = 1:length(anchor_states)
            for i_a = 1:A
                Q(anchor_states(i_s),i_a) = Qstar(anchor_states(i_s),i_a) + mean(randn([1,N_samples])*std_noise_reward);
            end
        end
        for i_a = 1:length(anchor_actions)
            for i_s = 1:S
                Q(i_s,anchor_actions(i_a)) = Qstar(i_s,anchor_actions(i_a)) + mean(randn([1,N_samples])*std_noise_reward);
            end
        end

        % First argument of my_pinv: 1e-3 at the first budget, afterwards
        % 1e-3 times this method's own error at the previous budget.
        if curr_epoch == 1
            epsilon_parameter = 1e-3;
        else
            epsilon_parameter = error_inf_uniform(curr_epoch-1)*1e-3;
        end
        % NOTE(review): my_pinv is defined outside this script; presumably a
        % regularised pseudo-inverse of the anchor sub-matrix -- confirm.
        my_pinv_mat = my_pinv(epsilon_parameter, Q(anchor_states,anchor_actions));

        % Fill every non-anchor cell with
        %   Q(i_s,anchor_actions) * my_pinv_mat * Q(anchor_states,i_a)
        % in one matrix product, then restore the observed anchor rows/columns.
        Q_anchor_rows = Q(anchor_states,:);
        Q_anchor_cols = Q(:,anchor_actions);
        Q = Q_anchor_cols*my_pinv_mat*Q_anchor_rows;
        Q(anchor_states,:) = Q_anchor_rows;
        Q(:,anchor_actions) = Q_anchor_cols;

        error_inf_uniform(curr_epoch) = max(abs(Qstar(:) - Q(:)));
        error_fro_uniform(curr_epoch) = norm(Qstar-Q,'fro');
    end

    %% Anchor-based estimate with oracle leveraged anchors
    for curr_epoch = 1:N_steps
        N_samples = ceil(N_samples_arr(curr_epoch)/((S+A)*deff));

        anchor_states = anchor_states_oracle;
        anchor_actions = anchor_actions_oracle;

        Q = zeros([S,A]);
        for i_s = 1:length(anchor_states)
            for i_a = 1:A
                Q(anchor_states(i_s),i_a) = Qstar(anchor_states(i_s),i_a) + mean(randn([1,N_samples])*std_noise_reward);
            end
        end
        for i_a = 1:length(anchor_actions)
            for i_s = 1:S
                Q(i_s,anchor_actions(i_a)) = Qstar(i_s,anchor_actions(i_a)) + mean(randn([1,N_samples])*std_noise_reward);
            end
        end

        % Uses the oracle method's own previous error (previously this read
        % the uniform method's error, a copy-paste slip).
        if curr_epoch == 1
            epsilon_parameter = 1e-3;
        else
            epsilon_parameter = error_inf_oracle(curr_epoch-1)*1e-3;
        end
        my_pinv_mat = my_pinv(epsilon_parameter, Q(anchor_states,anchor_actions));

        Q_anchor_rows = Q(anchor_states,:);
        Q_anchor_cols = Q(:,anchor_actions);
        Q = Q_anchor_cols*my_pinv_mat*Q_anchor_rows;
        Q(anchor_states,:) = Q_anchor_rows;
        Q(:,anchor_actions) = Q_anchor_cols;

        error_inf_oracle(curr_epoch) = max(abs(Qstar(:) - Q(:)));
        error_fro_oracle(curr_epoch) = norm(Qstar-Q,'fro');
    end

    %% Anchor-based estimate with leveraged anchors estimated from data
    for curr_epoch = 1:N_steps
        % Half of the budget finds the anchors, the other half observes them.
        N_samples = ceil(N_samples_arr(curr_epoch)/(2*(S+A)*deff));

        % Phase 1: sample cells uniformly at random and count visits per cell.
        % Passing the output size keeps one count per cell; the count of the
        % highest visited index must not be overwritten when padding.
        visited_states = randi(S*A,[1,ceil(N_samples_arr(curr_epoch)/2)]);
        acc_states = accumarray(visited_states(:),1,[S*A,1]);

        % Empirical mean of the noisy observations; unvisited cells stay 0.
        % Cell (i_s,i_a) corresponds to linear index i_a+(i_s-1)*A.
        Qtmp = zeros([S,A]);
        for i_s = 1:S
            for i_a = 1:A
                N_tmp = acc_states(i_a+(i_s-1)*A);
                if N_tmp > 0
                    Qtmp(i_s,i_a) = Qstar(i_s,i_a) + mean(randn([1,N_tmp])*std_noise_reward);
                end
            end
        end

        % Estimated anchors: rows of the top-d singular vectors of Qtmp with
        % the largest Euclidean norm.
        [UQ,~,WQ] = svds(Qtmp,d);
        UQd = UQ(:,1:d);
        WQd = WQ(:,1:d);
        ellU = sqrt(sum(UQd.^2,2))';
        ellW = sqrt(sum(WQd.^2,2))';

        probsU = ellU/sum(ellU);
        [~,anchor_states] = maxk(probsU,deff);

        probsW = ellW/sum(ellW);
        [~,anchor_actions] = maxk(probsW,deff);

        % Phase 2: observe the selected anchor rows and columns.
        Q = zeros([S,A]);
        for i_s = 1:length(anchor_states)
            for i_a = 1:A
                Q(anchor_states(i_s),i_a) = Qstar(anchor_states(i_s),i_a) + mean(randn([1,N_samples])*std_noise_reward);
            end
        end
        for i_a = 1:length(anchor_actions)
            for i_s = 1:S
                Q(i_s,anchor_actions(i_a)) = Qstar(i_s,anchor_actions(i_a)) + mean(randn([1,N_samples])*std_noise_reward);
            end
        end

        % Uses the leveraged method's own previous error (previously this
        % read the uniform method's error, a copy-paste slip).
        if curr_epoch == 1
            epsilon_parameter = 1e-3;
        else
            epsilon_parameter = error_inf_leveraged(curr_epoch-1)*1e-3;
        end
        my_pinv_mat = my_pinv(epsilon_parameter, Q(anchor_states,anchor_actions));

        Q_anchor_rows = Q(anchor_states,:);
        Q_anchor_cols = Q(:,anchor_actions);
        Q = Q_anchor_cols*my_pinv_mat*Q_anchor_rows;
        Q(anchor_states,:) = Q_anchor_rows;
        Q(:,anchor_actions) = Q_anchor_cols;

        error_inf_leveraged(curr_epoch) = max(abs(Qstar(:) - Q(:)));
        error_fro_leveraged(curr_epoch) = norm(Qstar-Q,'fro');
    end

    %% Baseline: rank-d truncated SVD of the uniformly sampled matrix
    for curr_epoch = 1:N_steps
        N_samples = N_samples_arr(curr_epoch);

        % Spend the whole budget on uniformly sampled cells.
        visited_states = randi(S*A,[1,N_samples]);
        acc_states = accumarray(visited_states(:),1,[S*A,1]);

        Qtmp = zeros([S,A]);
        for i_s = 1:S
            for i_a = 1:A
                N_tmp = acc_states(i_a+(i_s-1)*A);
                if N_tmp > 0
                    Qtmp(i_s,i_a) = Qstar(i_s,i_a) + mean(randn([1,N_tmp])*std_noise_reward);
                end
            end
        end

        % Dedicated output names so the ground-truth factors are not clobbered.
        [Usvd, Ssvd, Vsvd] = svds(Qtmp,d);
        Q = Usvd(:,1:d)*Ssvd(1:d,1:d)*Vsvd(:,1:d)';

        error_inf_svd(curr_epoch) = max(abs(Qstar(:) - Q(:)));
        error_fro_svd(curr_epoch) = norm(Qstar-Q,'fro');
    end

    %% Store this experiment's error curves

    uniform_list_inf(i_experiment,:) = error_inf_uniform;
    uniform_list_fro(i_experiment,:) = error_fro_uniform;

    leveraged_list_inf(i_experiment,:) = error_inf_leveraged;
    leveraged_list_fro(i_experiment,:) = error_fro_leveraged;

    leveraged_oracle_list_inf(i_experiment,:) = error_inf_oracle;
    leveraged_oracle_list_fro(i_experiment,:) = error_fro_oracle;

    svd_list_inf(i_experiment,:) = error_inf_svd;
    svd_list_fro(i_experiment,:) = error_fro_svd;

end

% Average each error curve over experiments. The dimension is given
% explicitly: with a single experiment (one row) mean(X) would average
% across budgets and return a scalar instead of one value per budget.
mean_uniform_inf = mean(uniform_list_inf,1);
mean_uniform_fro = mean(uniform_list_fro,1);

mean_leveraged_inf = mean(leveraged_list_inf,1);
mean_leveraged_fro = mean(leveraged_list_fro,1);

mean_oracle_inf = mean(leveraged_oracle_list_inf,1);
mean_oracle_fro = mean(leveraged_oracle_list_fro,1);

mean_svd_inf = mean(svd_list_inf,1);
mean_svd_fro = mean(svd_list_fro,1);


%% Inf norm
% Mean max-abs-entry error of each method versus the sample budget, with a
% shaded band of +/- one standard deviation across experiments.
figure()
set(gcf,'renderer','Painters')
x = N_samples_arr;
x2 = [x, fliplr(x)];   % band outline: forward along the upper curve, back along the lower one

err_lists   = {uniform_list_inf, leveraged_list_inf, leveraged_oracle_list_inf, svd_list_inf};
band_colors = {'b','r','y','m'};
n_methods   = numel(err_lists);
err_means   = cell(1,n_methods);

% Shaded bands first so the mean curves are drawn on top of them.
for k = 1:n_methods
    % Statistics are taken along dimension 1 (experiments) explicitly so a
    % single-experiment run still yields one value per budget.
    err_means{k} = mean(err_lists{k},1);
    err_std = std(err_lists{k},0,1);
    curve1 = err_means{k} + err_std;
    curve2 = err_means{k} - err_std;
    inBetween = [curve1, fliplr(curve2)];
    fill(x2, inBetween, band_colors{k},'FaceAlpha',0.05,'LineStyle','none');
    hold on
end

% Keep the line handles so the legend lists only the mean curves.
line_handles = gobjects(1,n_methods);
for k = 1:n_methods
    line_handles(k) = plot(N_samples_arr, err_means{k});
end

xlabel("number of samples", 'interpreter', 'latex','FontSize',12)
ylabel("$\Vert \widehat{M} - M^\star \Vert_{\infty}$", 'interpreter', 'latex','FontSize',16)
legend(line_handles, {'uniform anchors','leveraged anchors','oracle anchors','SVD'}, 'interpreter','latex','FontSize',12)
set(gca,'TickLabelInterpreter','latex')

%% Frobenius norm
% Mean Frobenius-norm error of each method versus the sample budget, with a
% shaded band of +/- one standard deviation across experiments.

figure()
set(gcf,'renderer','Painters')
x = N_samples_arr;
x2 = [x, fliplr(x)];   % band outline: forward along the upper curve, back along the lower one

err_lists   = {uniform_list_fro, leveraged_list_fro, leveraged_oracle_list_fro, svd_list_fro};
band_colors = {'b','r','y','m'};
n_methods   = numel(err_lists);
err_means   = cell(1,n_methods);

% Shaded bands first so the mean curves are drawn on top of them.
for k = 1:n_methods
    % Statistics are taken along dimension 1 (experiments) explicitly so a
    % single-experiment run still yields one value per budget.
    err_means{k} = mean(err_lists{k},1);
    err_std = std(err_lists{k},0,1);
    curve1 = err_means{k} + err_std;
    curve2 = err_means{k} - err_std;
    inBetween = [curve1, fliplr(curve2)];
    fill(x2, inBetween, band_colors{k},'FaceAlpha',0.05,'LineStyle','none');
    hold on
end

% Keep the line handles so the legend lists only the mean curves.
line_handles = gobjects(1,n_methods);
for k = 1:n_methods
    line_handles(k) = plot(N_samples_arr, err_means{k});
end

xlabel("number of samples", 'interpreter', 'latex','FontSize',12)
ylabel("$\Vert \widehat{M} - M^\star \Vert_{\mathrm{F}}$", 'interpreter', 'latex','FontSize',16)
legend(line_handles, {'uniform anchors','leveraged anchors','oracle anchors','SVD'}, 'interpreter','latex','FontSize',12)
set(gca,'TickLabelInterpreter','latex')

