/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.stats;

import com.aliasi.stats.MultivariateConstant;
import com.aliasi.stats.MultivariateDistribution;
import com.aliasi.util.AbstractExternalizable;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A trainable multivariate distribution over string-labeled outcomes,
 * estimated by maximum likelihood from observed counts.
 *
 * <p>Each distinct label passed to {@link #train(String, long)} is assigned
 * the next integer outcome index, starting at 0. The probability of an
 * outcome is its count divided by the total of all counts.
 *
 * <p>Instances can be compiled with {@link #compileTo(ObjectOutput)}. This
 * writes the labels and their relative frequencies; reading them back yields
 * a {@code MultivariateConstant}.
 */
public class MultivariateEstimator
extends MultivariateDistribution
implements Serializable {
    static final long serialVersionUID = 1171641384366463097L;
    /** Maps each known label to its outcome index. */
    final Map<String, Integer> mLabelToIndex;
    /** Label for each outcome index, in index order. */
    final List<String> mIndexToLabel;
    /** Observed count for each outcome index, in index order. */
    final List<Long> mIndexToCount;
    /** Sum of all entries in {@link #mIndexToCount}. */
    long mTotalCount = 0L;
    /** Index that will be assigned to the next previously unseen label. */
    int mNextIndex = 0;
    static final Long[] EMPTY_LONG_ARRAY = new Long[0];

    /**
     * Constructs an estimator with no outcomes and a total count of zero.
     */
    public MultivariateEstimator() {
        this(new HashMap<String, Integer>(), new ArrayList<String>(), new ArrayList<Long>());
    }

    private MultivariateEstimator(Map<String, Integer> labelToIndex, List<String> indexToLabel, List<Long> indexToCount) {
        this.mLabelToIndex = labelToIndex;
        this.mIndexToLabel = indexToLabel;
        this.mIndexToCount = indexToCount;
    }

    /**
     * Throws if {@code a + b} would exceed {@code Long.MAX_VALUE}.
     *
     * <p>The check is only meaningful for non-negative {@code b}. For negative
     * {@code b} the expression {@code Long.MAX_VALUE - b} itself overflows.
     *
     * @param a first summand
     * @param b second summand, expected to be non-negative
     * @throws IllegalArgumentException if the sum would overflow
     */
    static void checkLongAddInRange(long a, long b) {
        if (Long.MAX_VALUE - b < a) {
            String msg = "Long addition overflow. a=" + a + " b=" + b;
            throw new IllegalArgumentException(msg);
        }
    }

    /**
     * Resets the count of a known outcome to zero and removes that count
     * from the total. The outcome keeps its index and label.
     *
     * @param outcomeLabel label of a previously trained outcome
     * @throws IllegalArgumentException if the label has never been trained
     */
    public void resetCount(String outcomeLabel) {
        Integer index = this.mLabelToIndex.get(outcomeLabel);
        if (index == null) {
            String msg = "May only reset known outcomes. Found outcome=" + outcomeLabel;
            throw new IllegalArgumentException(msg);
        }
        long currentCount = this.mIndexToCount.get(index);
        this.mTotalCount -= currentCount;
        this.mIndexToCount.set(index, 0L);
    }

    /**
     * Adds {@code increment} observations of the specified outcome.
     *
     * <p>A previously unseen label is assigned the next outcome index.
     * All validation happens before any state is modified. If this method
     * throws, the estimator is left unchanged.
     *
     * @param outcomeLabel label of the observed outcome
     * @param increment number of observations to add; must be at least 1
     * @throws IllegalArgumentException if {@code increment} is less than 1,
     *     or if adding it would overflow the outcome's count or the total count
     */
    public void train(String outcomeLabel, long increment) {
        if (increment < 1L) {
            String msg = "Increment must be positive. Found increment=" + increment;
            throw new IllegalArgumentException(msg);
        }
        // The total must be checked too. Previously it was incremented
        // unchecked, so it could wrap negative. It was also incremented
        // before the per-outcome check, which left the estimator
        // inconsistent when that check failed.
        MultivariateEstimator.checkLongAddInRange(this.mTotalCount, increment);
        Integer indexInteger = this.mLabelToIndex.get(outcomeLabel);
        if (indexInteger == null) {
            int index = this.mNextIndex++;
            this.mLabelToIndex.put(outcomeLabel, index);
            this.mIndexToLabel.add(index, outcomeLabel);
            this.mIndexToCount.add(index, increment);
        } else {
            int index = indexInteger;
            long currentCount = this.mIndexToCount.get(index);
            MultivariateEstimator.checkLongAddInRange(currentCount, increment);
            this.mIndexToCount.set(index, currentCount + increment);
        }
        this.mTotalCount += increment;
    }

    /**
     * Returns the outcome index for the specified label.
     *
     * @param outcomeLabel label to look up
     * @return the label's outcome index, or {@code -1} if the label is unknown
     */
    @Override
    public long outcome(String outcomeLabel) {
        Integer outcome = this.mLabelToIndex.get(outcomeLabel);
        return outcome == null ? -1L : outcome.longValue();
    }

    /**
     * Returns the label for the specified outcome index.
     *
     * @param outcome outcome index, in the range {@code [0, numDimensions())}
     * @return the label assigned to that index
     * @throws IllegalArgumentException if the index is out of range
     */
    @Override
    public String label(long outcome) {
        if (outcome < 0L || outcome >= (long)this.mNextIndex) {
            String msg = "Outcome must be between 0 and max. Max outcome=" + this.maxOutcome() + " Argument outcome=" + outcome;
            throw new IllegalArgumentException(msg);
        }
        return this.mIndexToLabel.get((int)outcome);
    }

    /**
     * Returns the number of distinct outcomes seen so far. Outcomes whose
     * counts have been reset are still included.
     *
     * @return number of known outcomes
     */
    @Override
    public int numDimensions() {
        return this.mIndexToLabel.size();
    }

    /**
     * Returns the maximum-likelihood probability of the specified outcome.
     * This is its count divided by {@link #trainingSampleCount()}.
     *
     * <p>Outcomes outside {@code [minOutcome(), maxOutcome()]} get probability
     * {@code 0.0}. If the total count is zero while in-range outcomes exist
     * (for example, after every count has been reset), the result is
     * {@code 0.0 / 0.0}, which is {@code NaN}.
     *
     * @param outcome outcome index
     * @return estimated probability of the outcome
     */
    @Override
    public double probability(long outcome) {
        if (outcome < this.minOutcome() || outcome > this.maxOutcome()) {
            return 0.0;
        }
        return (double)this.getCount(outcome) / (double)this.trainingSampleCount();
    }

    /**
     * Returns the observed count for the specified outcome index.
     *
     * @param outcome outcome index; validated by the inherited
     *     {@code checkOutcome} (presumably throws when out of range; confirm
     *     against {@code MultivariateDistribution})
     * @return the outcome's count
     */
    public long getCount(long outcome) {
        this.checkOutcome(outcome);
        // train() only ever stores boxed primitives, so the entry is never null.
        return this.mIndexToCount.get((int)outcome);
    }

    /**
     * Returns the observed count for the specified outcome label.
     *
     * @param outcomeLabel label of a previously trained outcome
     * @return the outcome's count
     * @throws IllegalArgumentException if the label has never been trained
     */
    public long getCount(String outcomeLabel) {
        Integer index = this.mLabelToIndex.get(outcomeLabel);
        if (index == null) {
            String msg = "May only count known outcomes by label. Found outcome=" + outcomeLabel;
            throw new IllegalArgumentException(msg);
        }
        return this.getCount(index.longValue());
    }

    /**
     * Returns the total of all outcome counts.
     *
     * @return total training count
     */
    public long trainingSampleCount() {
        return this.mTotalCount;
    }

    /**
     * Writes a compiled form of this estimator to the specified output.
     * The compiled form holds the labels and their relative frequencies.
     *
     * @param objOut stream to write to
     * @throws IOException if the underlying stream fails
     */
    public void compileTo(ObjectOutput objOut) throws IOException {
        objOut.writeObject(new Externalizer(this));
    }

    /**
     * Serialization proxy. It writes an estimator's labels and its
     * count/total ratios, and reads them back as a {@code MultivariateConstant}.
     */
    static class Externalizer
    extends AbstractExternalizable {
        private static final long serialVersionUID = 2913496935213914118L;
        final MultivariateEstimator mEstimator;

        /** No-argument constructor used when reading the compiled form back. */
        public Externalizer() {
            this.mEstimator = null;
        }

        public Externalizer(MultivariateEstimator estimator) {
            this.mEstimator = estimator;
        }

        /**
         * Writes the label array, followed by an array of per-outcome
         * ratios (count divided by total count). If the total count is zero,
         * every ratio is {@code NaN}.
         */
        @Override
        public void writeExternal(ObjectOutput out) throws IOException {
            String[] labels = this.mEstimator.mIndexToLabel.toArray(Strings.EMPTY_STRING_ARRAY);
            out.writeObject(labels);
            Long[] counts = this.mEstimator.mIndexToCount.toArray(EMPTY_LONG_ARRAY);
            double totalCount = this.mEstimator.mTotalCount;
            double[] ratios = new double[counts.length];
            for (int i = 0; i < ratios.length; ++i) {
                ratios[i] = counts[i].doubleValue() / totalCount;
            }
            out.writeObject(ratios);
        }

        /**
         * Reads the label and ratio arrays in the order
         * {@link #writeExternal(ObjectOutput)} wrote them and wraps them in a
         * {@code MultivariateConstant}.
         */
        @Override
        public Object read(ObjectInput in) throws ClassNotFoundException, IOException {
            String[] labels = (String[])in.readObject();
            double[] ratios = (double[])in.readObject();
            return new MultivariateConstant(ratios, labels);
        }
    }
}

