/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.optimize;

import cc.mallet.optimize.Optimizable;
import cc.mallet.optimize.Optimizer;
import cc.mallet.types.MatrixOps;
import cc.mallet.util.MalletLogger;
import java.text.DecimalFormat;
import java.util.Arrays;
import java.util.logging.Logger;

/**
 * Stochastic gradient ascent over batches with per-parameter gain (step-size)
 * adaptation, following the update rules of Schraudolph's Stochastic
 * Meta-Descent (SMD).
 *
 * <p>The update rules are written for minimisation, so each batch gradient is
 * negated before use. With {@code g} the negated batch gradient and {@code v}
 * the gradient trace, every parameter {@code i} is updated per batch as:
 * <pre>
 *   gain[i] *= max(0.5, 1 - mu * g[i] * v[i])
 *   x[i]    -= gain[i] * g[i]
 *   v[i]     = LAMBDA * v[i] - gain[i] * (g[i] + LAMBDA * (H v)[i])
 * </pre>
 * {@code H v} is a finite-difference Hessian-vector product. When
 * {@link #setUseHessian(boolean)} is {@code false}, it is instead approximated
 * by {@code v} itself (identity Hessian).
 *
 * <p>Gains and the gradient trace are kept between calls to {@code optimize},
 * so optimisation can be resumed a few iterations at a time. Instances hold
 * mutable optimisation state and perform no synchronisation.
 */
public class StochasticMetaAscent
implements Optimizer.ByBatches {
    private static final Logger logger = MalletLogger.getLogger(StochasticMetaAscent.class.getName());
    /** Number of passes over all batches used by {@link #optimize(int, int[])}. */
    private static final int MAX_ITER = 200;
    /** Decay factor applied to the gradient trace. */
    private static final double LAMBDA = 1.0;
    /** Relative tolerance of the convergence test. */
    private static final double TOLERANCE = 0.01;
    /** Keeps the convergence test's right-hand side positive when both values are zero. */
    private static final double EPS = 1.0E-10;
    /** Step size of the forward finite difference used for the Hessian-vector product. */
    private static final double HESSIAN_EPS = 1.0E-6;
    private double mu = 0.1;
    private int totalIterations = 0;
    private double eta_init = 0.03;
    private boolean useHessian = true;
    private double[] gain;
    private double[] gradientTrace;
    Optimizable.ByBatchGradient maxable = null;

    /**
     * Creates an optimizer for the given batch-gradient optimizable.
     *
     * @param maxable the objective whose batch value is maximised
     */
    public StochasticMetaAscent(Optimizable.ByBatchGradient maxable) {
        this.maxable = maxable;
    }

    /**
     * Sets the initial per-parameter gain. This only takes effect when the
     * gains are (re)initialised, i.e. on the first {@code optimize} call or
     * when the parameter count changes.
     *
     * @param step initial gain for every parameter
     */
    public void setInitialStep(double step) {
        this.eta_init = step;
    }

    /**
     * Sets the meta step size {@code mu} used to adapt the gains.
     *
     * @param m meta step size
     */
    public void setMu(double m) {
        this.mu = m;
    }

    /**
     * Chooses between a finite-difference Hessian-vector product
     * ({@code true}) and an identity-Hessian approximation ({@code false})
     * when updating the gradient trace.
     *
     * @param flag whether to compute the Hessian-vector product
     */
    public void setUseHessian(boolean flag) {
        this.useHessian = flag;
    }

    /**
     * Runs up to {@link #MAX_ITER} passes over all batches.
     *
     * @param numBatches number of batches per pass
     * @param batchAssignments passed unchanged to the optimizable's batch methods
     * @return {@code true} if the convergence test was met within the budget
     * @throws IllegalArgumentException if a batch value evaluates to NaN
     */
    public boolean optimize(int numBatches, int[] batchAssignments) {
        return this.optimize(MAX_ITER, numBatches, batchAssignments);
    }

    /**
     * Runs up to {@code numIterations} passes over all batches. Each batch
     * gets one SMD step.
     *
     * <p>After each pass, two sums are compared: the batch values before each
     * batch's step, and the batch values after it. The method returns as soon
     * as their relative difference falls below {@link #TOLERANCE}.
     *
     * @param numIterations maximum number of passes over all batches
     * @param numBatches number of batches per pass
     * @param batchAssignments passed unchanged to the optimizable's batch methods
     * @return {@code true} if converged, {@code false} if the budget ran out
     * @throws IllegalArgumentException if a batch value evaluates to NaN
     */
    public boolean optimize(int numIterations, int numBatches, int[] batchAssignments) {
        int numParameters = this.maxable.getNumParameters();
        double[] parameters = new double[numParameters];
        double[] gradient = new double[numParameters];
        double[] hessianProduct = new double[numParameters];
        // Gains and trace persist across calls so optimisation can be resumed.
        // (Re)create them on first use, or if the parameter count has changed.
        if (this.gain == null || this.gain.length != numParameters) {
            System.err.println("StochasticMetaAscent: initialStep=" + this.eta_init + "  metaStep=" + this.mu);
            this.gain = new double[numParameters];
            Arrays.fill(this.gain, this.eta_init);
            this.gradientTrace = new double[numParameters];
        }
        this.maxable.getParameters(parameters);
        for (int iteration = 0; iteration < numIterations; ++iteration) {
            // Sums of the batch values before and after each batch's step in this pass.
            double valueBefore = 0.0;
            double valueAfter = 0.0;
            for (int batch = 0; batch < numBatches; ++batch) {
                logger.info("Iteration " + (this.totalIterations + iteration) + ", batch " + batch + " of " + numBatches);
                this.maxable.getParameters(parameters);
                double initialValue = this.maxable.getBatchValue(batch, batchAssignments);
                valueBefore += initialValue;
                if (Double.isNaN(initialValue)) {
                    throw new IllegalArgumentException("NaN in value computation.  Probably you need to reduce initialStep or metaStep.");
                }
                this.maxable.getBatchValueGradient(gradient, batch, batchAssignments);
                // The update rules are written for minimisation: descend on -value.
                MatrixOps.timesEquals(gradient, -1.0);
                if (this.useHessian) {
                    this.computeHessianProduct(this.maxable, parameters, batch, batchAssignments, gradient, this.gradientTrace, hessianProduct);
                }
                this.reportOnVec("x", parameters);
                this.reportOnVec("step", this.gain);
                this.reportOnVec("grad", gradient);
                this.reportOnVec("trace", this.gradientTrace);
                this.updateGainsParametersAndTrace(parameters, gradient, hessianProduct);
                this.maxable.setParameters(parameters);
                double finalValue = this.maxable.getBatchValue(batch, batchAssignments);
                valueAfter += finalValue;
                logger.info("StochasticMetaAscent: initial value: " + initialValue + "  final value:" + finalValue);
            }
            logger.info("StochasticMetaAscent: Value at iteration (" + (this.totalIterations + iteration) + ")= " + valueAfter);
            if (hasConverged(valueBefore, valueAfter)) {
                logger.info("Stochastic Meta Ascent: Value difference " + Math.abs(valueAfter - valueBefore) + " below " + "tolerance; saying converged.");
                // iteration is zero-based, so iteration + 1 passes have completed.
                this.totalIterations += iteration + 1;
                return true;
            }
        }
        this.totalIterations += numIterations;
        return false;
    }

    /**
     * Applies one SMD step to every parameter, in place. For each parameter it
     * adapts {@code gain}, moves {@code parameters} against {@code gradient},
     * then updates {@code gradientTrace}. The gain update must read the trace
     * from the previous step, so the statement order below matters.
     */
    private void updateGainsParametersAndTrace(double[] parameters, double[] gradient, double[] hessianProduct) {
        for (int i = 0; i < parameters.length; ++i) {
            // Right after (re)initialisation the trace is all zero, so the gain stays at eta_init.
            this.gain[i] *= Math.max(0.5, 1.0 - this.mu * gradient[i] * this.gradientTrace[i]);
            parameters[i] -= this.gain[i] * gradient[i];
            // Without the Hessian, H is approximated by the identity, so H*v == v.
            double curvature = this.useHessian ? hessianProduct[i] : this.gradientTrace[i];
            this.gradientTrace[i] = LAMBDA * this.gradientTrace[i] - this.gain[i] * (gradient[i] + LAMBDA * curvature);
        }
    }

    /**
     * Relative-difference convergence test between two aggregate values.
     *
     * @return {@code true} if {@code 2|after - before| <= TOLERANCE * (|after| + |before| + EPS)}
     */
    private static boolean hasConverged(double before, double after) {
        return 2.0 * Math.abs(after - before) <= TOLERANCE * (Math.abs(after) + Math.abs(before) + EPS);
    }

    /** Prints min, max, mean, 2-norm and absolute norm of {@code v} to standard output, labelled with {@code s}. */
    private void reportOnVec(String s, double[] v) {
        DecimalFormat f = new DecimalFormat("0.####");
        System.out.println("StochasticMetaAscent: " + s + ":" + "  min " + f.format(MatrixOps.min(v)) + "  max " + f.format(MatrixOps.max(v)) + "  mean " + f.format(MatrixOps.mean(v)) + "  2norm " + f.format(MatrixOps.twoNorm(v)) + "  abs-norm " + f.format(MatrixOps.absNorm(v)));
    }

    /**
     * Approximates {@code H * vector} with a forward finite difference of the
     * batch gradient. {@code H} is the Hessian of the negated batch value,
     * which matches {@code currentGradient} having been negated by the caller.
     *
     * <p>{@code parameters} is not modified. The optimizable is temporarily
     * moved to {@code parameters + HESSIAN_EPS * vector`} and then reset to
     * {@code parameters}.
     *
     * @param currentGradient negated batch gradient at {@code parameters}
     * @param vector the vector to multiply by the Hessian
     * @param result receives the Hessian-vector product
     */
    private void computeHessianProduct(Optimizable.ByBatchGradient maxable, double[] parameters, int batchIndex, int[] batchAssignments, double[] currentGradient, double[] vector, double[] result) {
        int numParameters = maxable.getNumParameters();
        double[] epsGradient = new double[numParameters];
        // Perturb a private copy: mutating the caller's array would shift the
        // point the subsequent SMD step starts from.
        double[] perturbed = Arrays.copyOf(parameters, numParameters);
        MatrixOps.plusEquals(perturbed, vector, HESSIAN_EPS);
        maxable.setParameters(perturbed);
        maxable.getBatchValueGradient(epsGradient, batchIndex, batchAssignments);
        // Restore with a fresh copy so later in-place edits of the caller's array stay local.
        maxable.setParameters(Arrays.copyOf(parameters, numParameters));
        for (int index = 0; index < result.length; ++index) {
            // epsGradient is the un-negated gradient, so negate it to match currentGradient.
            result[index] = (-epsGradient[index] - currentGradient[index]) / HESSIAN_EPS;
        }
    }
}

