/*
 * Decompiled with CFR 0.152.
 */
package de.bwaldvogel.liblinear;

import de.bwaldvogel.liblinear.Function;
import de.bwaldvogel.liblinear.Linear;
import java.util.concurrent.atomic.AtomicBoolean;

/**
 * Trust-region Newton solver that minimises the objective exposed by a {@link Function}.
 *
 * <p>Each outer iteration builds a trial step with a diagonally scaled conjugate-gradient
 * solve restricted to the current trust region ({@link #trpcg}), compares the actual
 * reduction of the objective with the reduction predicted by the quadratic model, and
 * grows or shrinks the trust-region radius accordingly.
 *
 * <p>NOTE(review): the callbacks on {@code Function} are stateful in the way this class
 * uses them ({@code fun} is always called before {@code grad}, and {@code get_diagH} /
 * {@code Hv} take no point argument) — confirm the exact contract against {@code Function}.
 */
class Tron {
    private final Function fun_obj;
    private final double eps;
    private final int max_iter;
    private final double eps_cg;

    /**
     * @param fun_obj  objective providing value, gradient, diagonal and Hessian-vector products
     * @param eps      relative gradient-norm tolerance; the solver stops once
     *                 {@code |g(w)| <= eps * |g(0)|}
     * @param max_iter maximum number of accepted outer iterations
     * @param eps_cg   relative tolerance of the inner conjugate-gradient solve
     */
    public Tron(Function fun_obj, double eps, int max_iter, double eps_cg) {
        this.fun_obj = fun_obj;
        this.eps = eps;
        this.max_iter = max_iter;
        this.eps_cg = eps_cg;
    }

    /**
     * Runs the optimisation starting from {@code w} and overwrites {@code w} in place with
     * every accepted iterate.
     *
     * @param w starting point on entry, last accepted iterate on return; it is copied with
     *          length {@code fun_obj.get_nr_variable()}
     */
    void tron(double[] w) {
        // Thresholds on actred/prered that classify how good a trial step was.
        final double eta0 = 1.0E-4;
        final double eta1 = 0.25;
        final double eta2 = 0.75;
        // Factors used to shrink/grow the trust-region radius delta.
        final double sigma1 = 0.25;
        final double sigma2 = 0.5;
        final double sigma3 = 4.0;
        // Blend factor between the identity and the diagonal returned by get_diagH.
        final double alpha_pcg = 0.01;

        int n = this.fun_obj.get_nr_variable();
        double[] s = new double[n];
        double[] r = new double[n];
        double[] g = new double[n];
        double[] M = new double[n];

        // Gradient norm at w = 0 is the reference for the relative stopping test.
        // (A freshly allocated double[] is already all zeros.)
        double[] w0 = new double[n];
        this.fun_obj.fun(w0);
        this.fun_obj.grad(w0, g);
        double gnorm0 = Tron.euclideanNorm(g);

        double f = this.fun_obj.fun(w);
        this.fun_obj.grad(w, g);
        double gnorm = Tron.euclideanNorm(g);

        boolean search = true;
        if (gnorm <= this.eps * gnorm0) {
            search = false;
        }

        double delta = 0.0;
        // True until the very first trial step has been evaluated. The radius must be
        // seeded and capped exactly once: iter only advances on an accepted step, so
        // keying this on "iter == 1" would reset delta after every rejected first step,
        // reproduce the same step, and never terminate.
        boolean firstTrial = true;
        int iter = 1;
        double[] w_new = new double[n];
        AtomicBoolean reach_boundary = new AtomicBoolean();

        while (iter <= this.max_iter && search) {
            this.fun_obj.get_diagH(M);
            for (int i = 0; i < n; ++i) {
                M[i] = 1.0 - alpha_pcg + alpha_pcg * M[i];
            }
            if (firstTrial) {
                delta = Math.sqrt(Tron.uTMv(n, g, M, g));
            }

            int cg_iter = this.trpcg(delta, g, M, s, r, reach_boundary);

            // Trial point w_new = w + s.
            System.arraycopy(w, 0, w_new, 0, n);
            Tron.daxpy(1.0, s, w_new);

            double gs = Tron.dot(g, s);
            // Reduction predicted by the quadratic model (r is the CG residual for s).
            double prered = -0.5 * (gs - Tron.dot(s, r));
            double fnew = this.fun_obj.fun(w_new);
            // Reduction actually achieved.
            double actred = f - fnew;

            double sMnorm = Math.sqrt(Tron.uTMv(n, s, M, s));
            if (firstTrial) {
                // Cap the initial radius by the length of the first step, once.
                delta = Math.min(delta, sMnorm);
                firstTrial = false;
            }

            // alpha * sMnorm is the step-length prediction used in the radius update.
            double alpha;
            if (fnew - f - gs <= 0.0) {
                alpha = sigma3;
            } else {
                alpha = Math.max(sigma1, -0.5 * (gs / (fnew - f - gs)));
            }

            // Update the radius according to the ratio of actual to predicted reduction.
            if (actred < eta0 * prered) {
                delta = Math.min(alpha * sMnorm, sigma2 * delta);
            } else if (actred < eta1 * prered) {
                delta = Math.max(sigma1 * delta, Math.min(alpha * sMnorm, sigma2 * delta));
            } else if (actred < eta2 * prered) {
                delta = Math.max(sigma1 * delta, Math.min(alpha * sMnorm, sigma3 * delta));
            } else if (reach_boundary.get()) {
                delta = sigma3 * delta;
            } else {
                delta = Math.max(delta, Math.min(alpha * sMnorm, sigma3 * delta));
            }

            Linear.info("iter %2d act %5.3e pre %5.3e delta %5.3e f %5.3e |g| %5.3e CG %3d%n", iter, actred, prered, delta, f, gnorm, cg_iter);

            if (actred > eta0 * prered) {
                // Step accepted: move to w_new and refresh the gradient there.
                ++iter;
                System.arraycopy(w_new, 0, w, 0, n);
                f = fnew;
                this.fun_obj.grad(w, g);
                gnorm = Tron.euclideanNorm(g);
                if (gnorm <= this.eps * gnorm0) {
                    break;
                }
            }
            if (f < -1.0E32) {
                Linear.info("WARNING: f < -1.0e+32%n");
                break;
            }
            if (prered <= 0.0) {
                Linear.info("WARNING: prered <= 0%n");
                break;
            }
            if (Math.abs(actred) <= 1.0E-12 * Math.abs(f) && Math.abs(prered) <= 1.0E-12 * Math.abs(f)) {
                Linear.info("WARNING: actred and prered too small%n");
                break;
            }
        }
    }

    /**
     * Conjugate-gradient solve, scaled by the diagonal {@code M}, for a step {@code s}
     * whose {@code M}-norm does not exceed {@code delta}.
     *
     * @param delta          trust-region radius, measured as {@code sqrt(s' M s)}
     * @param g              current gradient (read only)
     * @param M              positive diagonal scaling (read only)
     * @param s              output: the computed step; overwritten
     * @param r              output: the residual {@code -g - H s} built up from the
     *                       Hessian-vector products; overwritten
     * @param reach_boundary output: set to {@code true} when the step was clipped to the
     *                       trust-region boundary, {@code false} otherwise
     * @return number of conjugate-gradient iterations performed
     */
    private int trpcg(double delta, double[] g, double[] M, double[] s, double[] r, AtomicBoolean reach_boundary) {
        int n = this.fun_obj.get_nr_variable();
        double[] d = new double[n];
        double[] Hd = new double[n];
        double[] z = new double[n];

        reach_boundary.set(false);
        for (int i = 0; i < n; ++i) {
            s[i] = 0.0;
            r[i] = -g[i];
            z[i] = r[i] / M[i];
            d[i] = z[i];
        }

        double zTr = Tron.dot(z, r);
        double cgtol = this.eps_cg * Math.sqrt(zTr);
        int cg_iter = 0;

        while (true) {
            if (Math.sqrt(zTr) <= cgtol) {
                break;
            }
            ++cg_iter;
            this.fun_obj.Hv(d, Hd);

            double alpha = zTr / Tron.dot(d, Hd);
            Tron.daxpy(alpha, d, s);

            double sMnorm = Math.sqrt(Tron.uTMv(n, s, M, s));
            if (sMnorm > delta) {
                Linear.info("cg reaches trust region boundary%n");
                reach_boundary.set(true);

                // Undo the full step, then move along d only as far as the boundary:
                // alpha is the positive root of |s + alpha d|_M = delta, with the
                // formula chosen by the sign of sTMd.
                Tron.daxpy(-alpha, d, s);
                double sTMd = Tron.uTMv(n, s, M, d);
                double sTMs = Tron.uTMv(n, s, M, s);
                double dTMd = Tron.uTMv(n, d, M, d);
                double dsq = delta * delta;
                double rad = Math.sqrt(sTMd * sTMd + dTMd * (dsq - sTMs));
                if (sTMd >= 0.0) {
                    alpha = (dsq - sTMs) / (sTMd + rad);
                } else {
                    alpha = (rad - sTMd) / dTMd;
                }
                Tron.daxpy(alpha, d, s);
                Tron.daxpy(-alpha, Hd, r);
                break;
            }

            Tron.daxpy(-alpha, Hd, r);
            for (int i = 0; i < n; ++i) {
                z[i] = r[i] / M[i];
            }
            double znewTrnew = Tron.dot(z, r);
            double beta = znewTrnew / zTr;
            // New search direction d = z + beta * d.
            Tron.scale(beta, d);
            Tron.daxpy(1.0, z, d);
            zTr = znewTrnew;
        }
        return cg_iter;
    }

    /** Computes {@code vector2 += constant * vector1} in place; a zero constant is a no-op. */
    private static void daxpy(double constant, double[] vector1, double[] vector2) {
        if (constant == 0.0) {
            return;
        }
        assert (vector1.length == vector2.length);
        for (int i = 0; i < vector1.length; ++i) {
            vector2[i] += constant * vector1[i];
        }
    }

    /** Returns the inner product of two equally long vectors. */
    private static double dot(double[] vector1, double[] vector2) {
        double product = 0.0;
        assert (vector1.length == vector2.length);
        for (int i = 0; i < vector1.length; ++i) {
            product += vector1[i] * vector2[i];
        }
        return product;
    }

    /**
     * Returns the Euclidean norm of {@code vector}, computed as {@code scale * sqrt(sum)}
     * where {@code scale} is the running maximum absolute value, so that the individual
     * squares are taken of ratios no larger than one.
     */
    private static double euclideanNorm(double[] vector) {
        int n = vector.length;
        if (n < 1) {
            return 0.0;
        }
        if (n == 1) {
            return Math.abs(vector[0]);
        }
        double scale = 0.0;
        double sum = 1.0;
        for (int i = 0; i < n; ++i) {
            if (vector[i] == 0.0) {
                continue;
            }
            double abs = Math.abs(vector[i]);
            if (scale < abs) {
                // New maximum: rescale the accumulated sum to the new reference.
                double t = scale / abs;
                sum = 1.0 + sum * (t * t);
                scale = abs;
            } else {
                double t = abs / scale;
                sum += t * t;
            }
        }
        return scale * Math.sqrt(sum);
    }

    /** Multiplies every element of {@code vector} by {@code constant} in place; 1.0 is a no-op. */
    private static void scale(double constant, double[] vector) {
        if (constant == 1.0) {
            return;
        }
        for (int i = 0; i < vector.length; ++i) {
            vector[i] *= constant;
        }
    }

    /**
     * Returns {@code sum_i u[i] * M[i] * v[i]} over the first {@code n} entries.
     *
     * <p>The loop is unrolled five-wide; the grouping is kept as is because changing the
     * summation order would change the floating-point result.
     */
    private static double uTMv(int n, double[] u, double[] M, double[] v) {
        int i;
        int m = n - 4;
        double res = 0.0;
        for (i = 0; i < m; i += 5) {
            res += u[i] * M[i] * v[i] + u[i + 1] * M[i + 1] * v[i + 1] + u[i + 2] * M[i + 2] * v[i + 2] + u[i + 3] * M[i + 3] * v[i + 3] + u[i + 4] * M[i + 4] * v[i + 4];
        }
        while (i < n) {
            res += u[i] * M[i] * v[i];
            ++i;
        }
        return res;
    }
}

