/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.linear.learner.loss;

import gov.sandia.cognition.math.Ring;
import gov.sandia.cognition.math.matrix.Matrix;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.mtj.SparseMatrixFactoryMTJ;
import org.apache.log4j.Logger;
import org.openimaj.math.matrix.CFMatrixUtils;
import org.openimaj.ml.linear.learner.loss.LossFunction;

public class MatSquareLossFunction
extends LossFunction {
    Logger logger = Logger.getLogger(MatSquareLossFunction.class);
    private SparseMatrixFactoryMTJ spf = SparseMatrixFactoryMTJ.INSTANCE;

    @Override
    public Matrix gradient(Matrix W) {
        Matrix ret = W.clone();
        if (CFMatrixUtils.containsInfinity((Matrix)this.X)) {
            throw new RuntimeException();
        }
        if (CFMatrixUtils.containsInfinity((Matrix)W)) {
            throw new RuntimeException();
        }
        Matrix resid = CFMatrixUtils.fastdot((Matrix)this.X, (Matrix)W);
        if (CFMatrixUtils.containsInfinity((Matrix)resid)) {
            CFMatrixUtils.fastdot((Matrix)this.X, (Matrix)W);
            throw new RuntimeException();
        }
        if (this.bias != null) {
            resid.plusEquals((Ring)this.bias);
        }
        CFMatrixUtils.fastminusEquals((Matrix)resid, (Matrix)this.Y);
        if (CFMatrixUtils.containsInfinity((Matrix)resid)) {
            throw new RuntimeException();
        }
        for (int t = 0; t < resid.getNumColumns(); ++t) {
            Vector xcol = ((Vector)this.X.getRow(t).scale(resid.getElement(t, t))).clone();
            CFMatrixUtils.fastsetcol((Matrix)ret, (int)t, (Vector)xcol);
        }
        return ret;
    }

    @Override
    public double eval(Matrix W) {
        Matrix resid = null;
        resid = W == null ? this.X.clone() : CFMatrixUtils.fastdot((Matrix)this.X, (Matrix)W);
        Matrix vnobias = resid.clone();
        if (this.bias != null) {
            resid.plusEquals((Ring)this.bias);
        }
        Matrix v = resid.clone();
        resid.minusEquals((Ring)this.Y);
        double retval = 0.0;
        for (int t = 0; t < resid.getNumColumns(); ++t) {
            double loss = resid.getElement(t, t);
            retval += loss * loss;
            this.logger.debug((Object)String.format("yr=%d,y=%3.2f,v=%3.2f,v(no bias)=%2.5f,error=%2.5f,serror=%2.5f", t, this.Y.getElement(t, t), v.getElement(t, t), vnobias.getElement(t, t), loss, loss * loss));
        }
        return retval;
    }

    @Override
    public boolean test_backtrack(Matrix W, Matrix grad, Matrix prox, double eta) {
        double normTmp;
        Matrix fastdotGradTmp;
        double normGradProx;
        Matrix tmp = (Matrix)prox.minus((Ring)W);
        double evalW = this.eval(W);
        double evalProx = this.eval(prox);
        return evalProx <= evalW + (normGradProx = CFMatrixUtils.sum((Matrix)(fastdotGradTmp = CFMatrixUtils.fastdot((Matrix)grad.transpose(), (Matrix)tmp)))) + (normTmp = 0.5 * eta * tmp.normFrobenius());
    }

    @Override
    public boolean isMatrixLoss() {
        return true;
    }
}

