/*
 * Decompiled with CFR 0.152.
 */
package org.cleartk.classifier.feature.transform.extractor;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.classifier.Feature;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.feature.extractor.CleartkExtractorException;
import org.cleartk.classifier.feature.extractor.simple.SimpleFeatureExtractor;
import org.cleartk.classifier.feature.transform.OneToOneTrainableExtractor_ImplBase;
import org.cleartk.classifier.feature.transform.TransformableFeature;

public class ZeroMeanUnitStddevExtractor<OUTCOME_T>
extends OneToOneTrainableExtractor_ImplBase<OUTCOME_T>
implements SimpleFeatureExtractor {
    private SimpleFeatureExtractor subExtractor;
    private boolean isTrained;
    private Map<String, MeanStddevTuple> meanStddevMap;

    /**
     * Creates an untrained extractor with the given name and no sub-extractor.
     *
     * <p>Equivalent to calling the two-argument constructor with a {@code null}
     * sub-extractor.
     *
     * @param name the name passed on to the superclass constructor
     */
    public ZeroMeanUnitStddevExtractor(String name) {
        this(name, null);
    }

    /**
     * Creates an untrained extractor that delegates raw feature extraction to
     * {@code subExtractor}.
     *
     * @param name the name passed on to the superclass constructor
     * @param subExtractor the extractor that {@link #extract(JCas, Annotation)} calls to
     *        obtain the raw features; stored as given, without a null check
     */
    public ZeroMeanUnitStddevExtractor(String name, SimpleFeatureExtractor subExtractor) {
        super(name);
        // Statistics only become available through train(...) or load(...).
        this.isTrained = false;
        this.subExtractor = subExtractor;
    }

    /**
     * Rescales one numeric feature to a z-score using the statistics stored for its name.
     *
     * <p>The result is {@code (value - mean) / stddev}, returned as a new feature named
     * {@code "ZMUS_" + originalName}. Two degenerate cases produce {@code 0.0} rather than
     * a {@code NullPointerException} or a NaN/Infinity value:
     * <ul>
     *   <li>no statistics are stored for the feature name (it was never seen in training);</li>
     *   <li>the stored standard deviation is exactly zero (the feature was constant).</li>
     * </ul>
     *
     * @param feature the feature to normalize; its value must be a {@link Number}
     * @return a new feature holding the normalized value
     * @throws ClassCastException if the feature value is not a {@link Number}
     */
    @Override
    protected Feature transform(Feature feature) {
        String featureName = feature.getName();
        double value = ((Number)feature.getValue()).doubleValue();
        MeanStddevTuple stats = this.meanStddevMap.get(featureName);

        // Default covers both "unknown feature" and "zero spread": in either case a
        // z-score is undefined, so the feature is mapped to the centre of the scale.
        double normalized = 0.0;
        if (stats != null && stats.stddev != 0.0) {
            normalized = (value - stats.mean) / stats.stddev;
        }
        return new Feature("ZMUS_" + featureName, normalized);
    }

    /**
     * Extracts features for {@code focusAnnotation} through the wrapped sub-extractor.
     *
     * <p>While this extractor is untrained, the raw features are returned bundled inside a
     * single {@link TransformableFeature} named with {@code this.name}. Once trained (or
     * loaded), every raw feature is passed through {@link #transform(Feature)} and the
     * transformed features are returned instead.
     *
     * @param view the CAS view handed to the sub-extractor
     * @param focusAnnotation the annotation handed to the sub-extractor
     * @return a new list holding either the single wrapper feature or the transformed features
     * @throws CleartkExtractorException if the sub-extractor fails
     */
    @Override
    public List<Feature> extract(JCas view, Annotation focusAnnotation) throws CleartkExtractorException {
        List<Feature> rawFeatures = this.subExtractor.extract(view, focusAnnotation);

        if (!this.isTrained) {
            // No statistics yet: hand the raw features over as one wrapper feature.
            List<Feature> wrapped = new ArrayList<Feature>();
            wrapped.add(new TransformableFeature(this.name, rawFeatures));
            return wrapped;
        }

        List<Feature> normalized = new ArrayList<Feature>();
        for (Feature rawFeature : rawFeatures) {
            normalized.add(this.transform(rawFeature));
        }
        return normalized;
    }

    /**
     * Learns a mean and standard deviation per feature name from the given instances.
     *
     * <p>Only features accepted by {@code isTransformable} are considered; each of those is
     * cast to {@link TransformableFeature} and its wrapped features are accumulated by name.
     * After all instances are read, the per-name mean and standard deviation replace any
     * previously stored statistics and the extractor is marked as trained.
     *
     * @param instances the training instances to scan
     * @throws IllegalArgumentException if a wrapped feature value is not a {@link Number};
     *         in that case the previously stored statistics are left untouched
     */
    @Override
    public void train(Iterable<Instance<OUTCOME_T>> instances) {
        Map<String, MeanVarianceRunningStat> featureStatsMap = new HashMap<String, MeanVarianceRunningStat>();
        for (Instance<OUTCOME_T> instance : instances) {
            for (Feature feature : instance.getFeatures()) {
                if (!this.isTransformable(feature)) {
                    continue;
                }
                for (Feature untransformedFeature : ((TransformableFeature)feature).getFeatures()) {
                    String featureName = untransformedFeature.getName();
                    Object featureValue = untransformedFeature.getValue();
                    if (!(featureValue instanceof Number)) {
                        throw new IllegalArgumentException("Cannot normalize non-numeric feature values");
                    }
                    // Single lookup; create the accumulator the first time a name is seen.
                    MeanVarianceRunningStat stats = featureStatsMap.get(featureName);
                    if (stats == null) {
                        stats = new MeanVarianceRunningStat();
                        featureStatsMap.put(featureName, stats);
                    }
                    stats.add(((Number)featureValue).doubleValue());
                }
            }
        }

        // Build the result separately and publish it in one assignment.
        Map<String, MeanStddevTuple> trained = new HashMap<String, MeanStddevTuple>();
        for (Map.Entry<String, MeanVarianceRunningStat> entry : featureStatsMap.entrySet()) {
            MeanVarianceRunningStat stats = entry.getValue();
            trained.put(entry.getKey(), new MeanStddevTuple(stats.mean(), stats.stddev()));
        }
        this.meanStddevMap = trained;
        this.isTrained = true;
    }

    /**
     * Writes the stored statistics to the file denoted by {@code zmusDataUri}.
     *
     * <p>One line is written per feature: name, mean and standard deviation separated by
     * tabs, with the numbers formatted by {@code %f} under {@link Locale#ROOT}.
     *
     * <p>NOTE(review): {@code %f} keeps six decimal places, so very small means or standard
     * deviations are written as {@code 0.000000}; confirm this precision is acceptable.
     *
     * @param zmusDataUri a {@code file:} URI naming the output file
     * @throws IOException if the file cannot be opened, written or closed
     */
    @Override
    public void save(URI zmusDataUri) throws IOException {
        File out = new File(zmusDataUri);
        BufferedWriter writer = new BufferedWriter(new FileWriter(out));
        try {
            for (Map.Entry<String, MeanStddevTuple> entry : this.meanStddevMap.entrySet()) {
                MeanStddevTuple tuple = entry.getValue();
                writer.append(String.format(Locale.ROOT, "%s\t%f\t%f\n", entry.getKey(), tuple.mean, tuple.stddev));
            }
        } finally {
            // Always release the file handle, including when formatting or writing throws.
            writer.close();
        }
    }

    /**
     * Reads statistics from the file denoted by {@code zmusDataUri} and marks the extractor
     * as trained.
     *
     * <p>Each non-empty line must hold at least three tab-separated fields: feature name,
     * mean and standard deviation. Empty lines are skipped. The statistics currently held by
     * this extractor are replaced only after the whole file has been read successfully.
     *
     * @param zmusDataUri a {@code file:} URI naming the input file
     * @throws IOException if the file cannot be read, or a non-empty line has fewer than
     *         three tab-separated fields
     * @throws NumberFormatException if a mean or standard deviation field is not a valid double
     */
    @Override
    public void load(URI zmusDataUri) throws IOException {
        File in = new File(zmusDataUri);
        Map<String, MeanStddevTuple> loaded = new HashMap<String, MeanStddevTuple>();
        BufferedReader reader = new BufferedReader(new FileReader(in));
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                if (line.length() == 0) {
                    continue;
                }
                String[] featureMeanStddev = line.split("\\t");
                if (featureMeanStddev.length < 3) {
                    throw new IOException("Malformed statistics line in " + in
                            + " (expected name<TAB>mean<TAB>stddev): " + line);
                }
                loaded.put(featureMeanStddev[0], new MeanStddevTuple(Double.parseDouble(featureMeanStddev[1]), Double.parseDouble(featureMeanStddev[2])));
            }
        } finally {
            // Always release the file handle, including when parsing fails.
            reader.close();
        }
        // Publish only a completely parsed map so a failed load leaves prior state intact.
        this.meanStddevMap = loaded;
        this.isTrained = true;
    }

    /**
     * Single-pass accumulator for the count, mean and spread of a stream of doubles.
     *
     * <p>Each call to {@link #add(double)} updates a running mean and a running sum of
     * squared deviations from the mean; the variance accessors divide that sum by a
     * sample-count-based denominator on demand.
     */
    public static class MeanVarianceRunningStat
    implements Serializable {
        private static final long serialVersionUID = 1L;
        /** Number of samples accumulated so far. */
        private int numSamples;
        /** Running mean before the most recent update. */
        private double meanOld;
        /** Running mean after the most recent update. */
        private double meanNew;
        /** Running sum of squared deviations before the most recent update. */
        private double varOld;
        /** Running sum of squared deviations after the most recent update. */
        private double varNew;

        /** Creates an empty accumulator. */
        public MeanVarianceRunningStat() {
            this.clear();
        }

        /**
         * Overwrites the accumulator state with externally supplied values.
         *
         * <p>Both the "old" and "new" halves of the running state are set, so that a
         * following {@link #add(double)} continues from the supplied values.
         *
         * @param n the number of samples the supplied state represents
         * @param mean the mean of those samples
         * @param variance the value stored in the spread accumulator; note that
         *        {@link #variance()} divides this accumulator by the sample count, so it
         *        is treated as a sum of squared deviations rather than a finished variance
         */
        public void init(int n, double mean, double variance) {
            this.numSamples = n;
            this.meanOld = this.meanNew = mean;
            this.varOld = this.varNew = variance;
        }

        /**
         * Adds one sample to the running statistics.
         *
         * @param x the sample value
         */
        public void add(double x) {
            ++this.numSamples;
            if (this.numSamples == 1) {
                this.meanOld = this.meanNew = x;
                // Reset both halves so no stale spread survives a clear() or gets serialized.
                this.varOld = this.varNew = 0.0;
            } else {
                this.meanNew = this.meanOld + (x - this.meanOld) / (double)this.numSamples;
                this.varNew = this.varOld + (x - this.meanOld) * (x - this.meanNew);
                this.meanOld = this.meanNew;
                this.varOld = this.varNew;
            }
        }

        /** Discards all accumulated samples and resets the running state. */
        public void clear() {
            this.numSamples = 0;
            this.meanOld = this.meanNew = 0.0;
            this.varOld = this.varNew = 0.0;
        }

        /**
         * @return the number of samples accumulated so far
         */
        public int getNumSamples() {
            return this.numSamples;
        }

        /**
         * @return the running mean, or {@code 0.0} when no samples have been added
         */
        public double mean() {
            return this.numSamples > 0 ? this.meanNew : 0.0;
        }

        /**
         * @return the accumulated sum of squared deviations divided by the sample count
         *         {@code n}, or {@code 0.0} when fewer than two samples have been added
         */
        public double variance() {
            return this.numSamples > 1 ? this.varNew / (double)this.numSamples : 0.0;
        }

        /**
         * @return the square root of {@link #variance()}
         */
        public double stddev() {
            return Math.sqrt(this.variance());
        }

        /**
         * Note that, despite the name, this accessor uses the {@code n - 1} denominator.
         *
         * @return the accumulated sum of squared deviations divided by {@code n - 1},
         *         or {@code 0.0} when fewer than two samples have been added
         */
        public double variancePop() {
            return this.numSamples > 1 ? this.varNew / (double)(this.numSamples - 1) : 0.0;
        }

        /**
         * @return the square root of {@link #variancePop()}
         */
        public double stddevPop() {
            return Math.sqrt(this.variancePop());
        }

        /**
         * Writes the default field data followed by the sample count, current mean and
         * current spread accumulator. The layout must stay in step with
         * {@link #readObject(ObjectInputStream)}.
         */
        private void writeObject(ObjectOutputStream out) throws IOException {
            out.defaultWriteObject();
            out.writeInt(this.numSamples);
            out.writeDouble(this.meanNew);
            out.writeDouble(this.varNew);
        }

        /**
         * Reads the default field data, then the explicitly written sample count, mean and
         * spread accumulator, assigning the latter two to both halves of the running state.
         */
        private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
            in.defaultReadObject();
            this.numSamples = in.readInt();
            this.meanOld = this.meanNew = in.readDouble();
            this.varOld = this.varNew = in.readDouble();
        }
    }

    /**
     * Plain holder pairing a mean with a standard deviation.
     */
    public static class MeanStddevTuple {
        /** The mean value. */
        public double mean;
        /** The standard deviation value. */
        public double stddev;

        /**
         * Creates a tuple holding the two given values.
         *
         * @param mean the value stored in {@link #mean}
         * @param stddev the value stored in {@link #stddev}
         */
        public MeanStddevTuple(double mean, double stddev) {
            this.stddev = stddev;
            this.mean = mean;
        }
    }
}

