% Benchmark script. For each class of the loaded CIFAR-10 features, treat
% that class as positive (+1) and all other classes as negative (-1). Then,
% for five positive-sample sizes np = 24, 48, ..., 120 (with nq = 1000
% negatives), run 96 seeded replicates in a PARFOR loop. Each replicate
% records the test AUC of three scorers:
%   s0 - the proposed method (two fmincon runs on objective (7), the second
%        using data-dependent per-sample weights),
%   s1 - linear logistic regression (fminunc on LogisticObj),
%   s2 - an online stochastic-gradient pairwise method ("SPAUC").
% After each (class, np) pair, the entire workspace is saved to
% data/cifar10_class<label>_np<np>.mat.
% NOTE(review): SAVE without a variable list also writes the loaded feature
% matrices into every result file; consider save(..., 's0', 's1', 's2').
clear

%% load pretrained features
% Presumably defines XTrain2, YTrain, XTest, YTest. The indexing below treats
% the X* arrays as one sample per column. TODO confirm against the .mat file.
load data/cifar10resnet.mat

% NOTE(review): d is never used below (the proposed method uses dn), and the
% missing semicolon echoes it to the console.
d = 50
%%
% loop over all different classes
% NOTE(review): assumes unique(YTrain) has at least 10 distinct labels.
cate = unique(YTrain);
for cateid = 1:10
    digit = cate(cateid);
    % loop over all different n_+ settings
    for scale = 1:5
        np = 24*scale; nq = 1000;
        % Per-replicate AUCs, filled by PARFOR concatenation reductions below.
        s0 = []; s1 = []; s2 = [];
        parfor seed = 1:96
            % Seed by replicate index, so each replicate draws the same
            % samples regardless of which worker executes it.
            rng(seed)
            % randomly sample positive and negative data
            xp0 = XTrain2(:, YTrain == digit);
            xq0 = XTrain2(:, YTrain ~= digit);

            % Draw np positives and nq negatives without replacement.
            idxp = randsample(size(xp0,2),np,false);
            xp = xp0(:, idxp);

            idxq = randsample(size(xq0,2),nq,false);
            xq = xq0(:, idxq);
            
            % generate training dataset
            % Labels: +1 for the chosen class, -1 otherwise; positives come first.
            X = [xp, xq];
            y = [ones(1, np) -ones(1,nq)];

            %%
            % Test set: all test samples of both groups. randsample with k equal
            % to the population size only permutes the columns.
            npt = sum(YTest == digit); nqt = sum(YTest ~= digit);
            xpt0 = XTest(:, YTest == digit);
            xqt0 = XTest(:, YTest ~= digit);

            idxpt = randsample(size(xpt0,2),npt,false);
            xpt = xpt0(:, idxpt);

            idxqt = randsample(size(xqt0,2),nqt,false);
            xqt = xqt0(:, idxqt);

            % generate testing dataset
            Xt = [xpt,xqt];
            yt = [ones(1, npt) -ones(1,nqt)];

            %%
            % the proposed method
            % NOTE(review): dn is assumed to equal the feature dimension of
            % XTrain2 (fea below is the identity map) - confirm.
            dn = 50;
            % para = [slope (dn x 1); intercept], as used by f = para(1:end-1)'*x + para(end).
            para = zeros(dn+1,1);
            % NOTE(review): old_para is not referenced afterwards.
            old_para = ones(dn+1,1) * inf;
            lambda = 0;
            % initiate the weights to ones. 
            w = ones(1,np+nq);
            % Identity feature map.
            fea = @(x) x;

            % the first run using (7), without any weights and regularization term.
            opts = optimset('fmincon');
            opts.MaxIter = 100000;
            opts.MaxFunEvals = 1000000;
            % rocloss returns the analytic gradient as its second output.
            opts.GradObj = 'on';
            wp = w(y==1);
            wq = w(y==-1);

            % Append a row of ones so the last entry of para acts as an intercept.
            K = fea(X);
            Kp = [K(:,y==1); ones(1,sum(y==1))];
            Kq = [K(:,y==-1); ones(1,sum(y==-1))];

            % objective function 
            obj1 = @(para) rocloss(para, Kp, Kq, wp, wq, lambda);
            n = length(y);
            % Linear inequalities A*para <= bnd encode
            % 0 <= K(:,i)'*para(1:end-1) + para(end) <= pi/2 for every training sample.
            [para, fval2, oflag] = fmincon(obj1, para,[[K';-K'], [ones(n,1);-ones(n,1)]],...
                [pi/2*ones(n,1); zeros(n,1)], [], [], [], [], [], opts);
            f = @(x) para(1:end-1)'*fea(x) + para(end);
            % Clip scores to [0, pi/2]. The f referenced inside the new handle is
            % the unclipped handle, captured by value when this line runs.
            f = @(x) min(max(f(x), 0), pi/2);

            % computing weights
            % w0: |F_+ - F_-| (empirical CDF gap from diffCDF, evaluated at each
            % training score), scaled by sqrt(2); then modulated by sin(score + pi/4).
            atanrhat = f(X);
            w0 = abs(diffCDF(atanrhat,y,atanrhat))*sqrt(2);
            % NOTE(review): this local named diff shadows MATLAB's diff builtin.
            diff = atanrhat+pi/4;
            w= w0 .* sin(diff);
            wp = w(y==1);
            wq = w(y==-1);

            % the second run using (7), with previously calculated weights. 
            % Warm-started from the first-run solution; regularization fixed at 0.
            obj2 = @(para) rocloss(para, Kp, Kq, wp, wq, 0);
            n = length(y);
            [para, fval2, oflag] = fmincon(obj2, para,[[K';-K'], [ones(n,1);-ones(n,1)]],...
                [pi/2*ones(n,1); zeros(n,1)], [], [], [], [], [], opts);
            % approximated optimal score function.
            f = @(x) para(1:end-1)'*fea(x) + para(end);
            f = @(x) min(max(f(x), 0), pi/2);

            %% calculating AUCs
            % The bare string and variable lines echo labels/values to the console.
            "proposed"
            [~,~,~, AUC_Proposed] = perfcurve(yt, f(Xt), 1);
            AUC_Proposed
            s0 = [s0, AUC_Proposed];

            fX = fea(X);
            % run logistic regression
            % fminunc with default options except Display, starting from zero.
            obj = @(theta) LogisticObj(theta, fX, y);
            opts = optimoptions("fminunc");
            opts.Display = "none";
            theta1 = fminunc(obj, zeros(dn+1,1),opts);
            
            % Score test samples with the fitted linear model (last entry = intercept).
            fXt = fea(Xt);
            tstar0t = theta1(1:end-1)'*fXt + theta1(end);
            [~,~,~,AUC_Logi] = perfcurve(yt,tstar0t,1);

            "logi"
            AUC_Logi
            s1 = [s1, AUC_Logi];

            % run auc maximization
            % (Disabled: batch pairwise AUC maximisation via aucmax, kept for reference.)
%             obj = @(theta) aucmax(theta, fX, y, 0);
%             opts = optimoptions("fminunc");
%             opts.Display = "none";
%             opts.MaxIterations = 10000000;
%             opts.MaxFunctionEvaluations = 1000000;
%             theta2 = fminunc(obj, zeros(dn,1), opts);
%             fXt = fea(Xt);
%             [~,~,~,AUC_aucmax] = perfcurve(yt,theta2'*fXt,1);
% 
%             "AUC max"
%             AUC_aucmax
%             s2 = [s2, AUC_aucmax];
            %% SPAUC method
            % Online pairwise learner. It keeps running sums/counts of positive
            % (s_p, n_p) and negative (s_n, n_n) samples, whose ratios ut/vt are
            % running class means, and takes a fixed-step (0.01) gradient step
            % on the linear scorer w after each streamed sample.
            n_p= 0;
            n_n= 0;
            s_p = zeros(dn, 1); 
            s_n = zeros(dn, 1);
            
            % Linear scorer without intercept. Reuses the name w, overwriting the
            % proposed method's sample weights, which are no longer needed.
            w = zeros(dn, 1);
            % Seed the running statistics with the first positive and first negative.
            n_p = n_p + 1;
            s_p = s_p + fea(xp(:,1));
            
            n_n = n_n + 1;
            s_n = s_n + fea(xq(:,1));
            
            % Stream the remaining training samples. X/y are redefined without the
            % two seeding samples. Order is fixed: all positives, then all negatives.
            X = fea([xp(:, 2:end), xq(:, 2:end)]);
            y = [ones(1, np-1) -ones(1,nq-1)];
            
            % tic/toc prints the wall time of the SGD loop.
            tic
            % 5000 passes; n_p/n_n and s_p/s_n keep accumulating across passes.
            for iter = 1:5000
                for i = 1:size(X,2)
                    % Positive-proportion estimate, updated only during the first
                    % pass and held fixed afterwards.
                    % NOTE(review): before sample i is added, i+1 samples have been
                    % counted (2 seeds + i-1 streamed); confirm the i+2 denominator.
                    if(iter == 1)
                        p_t = n_p ./ (i+2);
                    end
            
                    % Add the current sample to its class's running statistics.
                    if(y(i) == 1)
                        n_p = n_p + 1;
                        s_p = s_p + X(:,i);
                    else
                        n_n = n_n + 1;
                        s_n = s_n + X(:,i);
                    end
            
                    ut = s_p / n_p;
                    vt = s_n / n_n;
            
                    % t1: term built from the sample's deviation from its own
                    % class's running mean.
                    if(y(i) == 1)
                        t1 = 2*(1-p_t)*(X(:,i)-ut)*(X(:,i)-ut)'*w;
                    else
                        t1 = 2*p_t*(X(:,i)-vt)*(X(:,i)-vt)'*w;
                    end
            
                    % t2: term built from the difference of the running class means.
                    t2 = 2*p_t*(1-p_t)*(vt - ut) + 2*p_t*(1-p_t)*(vt - ut)*(vt - ut)'*w;
            
            
                    F_prime = t1 + t2;
            
                    w = w - .01*F_prime;
            
            %         w
            
                end
            end
            toc
            
            % Test AUC of the linear scorer w'*x.
            fXt = fea(Xt);
            [~,~,~,AUC_aucmax] = perfcurve(yt,w'*fXt,1);
            
            "AUC max"
            AUC_aucmax
            s2 = [s2, AUC_aucmax];
        end
        % Saves the whole workspace (including s0, s1, s2) for this (class, np) pair.
        save(sprintf("data/cifar10_class%d_np%d.mat", digit, np));
    end
end
%% difference between two CDFs evaluated on rt
function [diff] = diffCDF(r, y, rt)
% DIFFCDF  Gap between the empirical CDFs of the +1 and -1 groups.
%   diff = diffCDF(r, y, rt) splits the scores r by label (y == 1 versus
%   y == -1). For every query point rt(j) it returns
%       #{positive scores < rt(j)} / #positives
%     - #{negative scores < rt(j)} / #negatives
%   using a strict inequality. As called in this script, r, y and rt are
%   row vectors, and the result is a row vector with one entry per rt.
posMask = (y == 1);
negMask = (y == -1);

% Comparison tables: one row per group member, one column per query point.
belowPos = bsxfun(@lt, r(posMask)', rt);
belowNeg = bsxfun(@lt, r(negMask)', rt);

% Column sums count members below each query point; normalise by group size.
diff = sum(belowPos, 1) / sum(posMask) - sum(belowNeg, 1) / sum(negMask);
end

%% (7) objective, with weights
function [f, g] = rocloss(para, Kp, Kq, wp, wq, lambda)
% ROCLOSS  Weighted objective (7), negated for minimisation, with gradient.
%   [f, g] = rocloss(para, Kp, Kq, wp, wq, lambda)
%     para   - parameter column vector. Its last entry multiplies the final
%              row of Kp/Kq and is excluded from the penalty.
%     Kp, Kq - positive / negative samples, one per column.
%     wp, wq - per-sample weights for the columns of Kp / Kq (row vectors
%              as called in this script).
%     lambda - ridge coefficient on para(1:end-1).
%   Returns
%     f = -( mean(wp.*sin(para'*Kp)) + mean(wq.*cos(para'*Kq)) )
%         + lambda*||para(1:end-1)||^2
%     g = gradient of f with respect to para.
nFeat = size(Kp, 1) - 1;          % number of penalised coefficients
slope = para(1:nFeat);

zP = para' * Kp;                  % scores of positive samples
zQ = para' * Kq;                  % scores of negative samples

reward = mean(wp .* sin(zP)) + mean(wq .* cos(zQ));
f = lambda * slope' * slope - reward;

% d/dpara of reward: sin -> cos for positives, cos -> -sin for negatives.
dReward = mean(wp .* cos(zP) .* Kp, 2) - mean(wq .* sin(zQ) .* Kq, 2);
g = 2*lambda .* [slope; 0] - dReward;
end

%% Logistic Regression objective
function [o] = LogisticObj(para, X, y)
% LOGISTICOBJ  Mean negative log-likelihood of a linear logistic model.
%   o = LogisticObj(para, X, y)
%     para - [alpha; b]: slope column vector followed by the intercept.
%     X    - samples, one per column.
%     y    - labels, one per column of X (+1/-1 in this script).
%   With z = alpha'*X + b, returns
%       o = -mean( log( exp(y.*z) ./ (exp(z) + exp(-z)) ) ),
%   which for y in {+1,-1} equals mean(log(1 + exp(-2*y.*z))).
%
%   The direct form overflows (Inf/Inf = NaN) once |z| exceeds ~709, which an
%   unconstrained optimiser can reach on well-separated data. The identity
%   log(exp(z) + exp(-z)) = |z| + log1p(exp(-2|z|)) gives the same value
%   while never exponentiating a positive argument.
alpha = para(1:end-1);
b = para(end);

z = alpha'*X + b;                       % linear scores, one per sample
az = abs(z);
% -log(m) = log(exp(z) + exp(-z)) - y.*z, with the log-sum-exp written stably.
nll = az + log1p(exp(-2*az)) - y.*z;
o = mean(nll, 2);
end

%% AUC maximization objective
function [f] = aucmax(theta, X, y, lambda)
% AUCMAX  Pairwise squared-loss AUC surrogate with a ridge penalty.
%   f = aucmax(theta, X, y, lambda) scores the columns of X with theta'*x
%   (no intercept). It averages (1 - (s_pos - s_neg))^2 over every
%   (positive, negative) pair, with positives at y == 1 and negatives at
%   y == -1, and adds lambda*theta'*theta.
posIdx = (y == 1);
negIdx = (y == -1);

scoreP = (theta' * X(:, posIdx))';    % column: one entry per positive
scoreQ = theta' * X(:, negIdx);       % row: one entry per negative

% Implicit expansion yields an (#pos x #neg) table of pair losses.
pairLoss = (1 - (scoreP - scoreQ)).^2;

f = sum(sum(pairLoss)) / sum(posIdx) / sum(negIdx) + lambda * theta' * theta;
end
