/*
 * Decompiled with CFR 0.152.
 */
package org.cleartk.classifier.feature.transform.extractor;

import com.google.common.collect.LinkedHashMultiset;
import com.google.common.collect.Multiset;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import org.apache.uima.jcas.JCas;
import org.apache.uima.jcas.tcas.Annotation;
import org.cleartk.classifier.Feature;
import org.cleartk.classifier.Instance;
import org.cleartk.classifier.feature.extractor.CleartkExtractorException;
import org.cleartk.classifier.feature.extractor.simple.SimpleFeatureExtractor;
import org.cleartk.classifier.feature.transform.OneToOneTrainableExtractor_ImplBase;
import org.cleartk.classifier.feature.transform.TransformableFeature;

/**
 * Feature extractor that rescales term-frequency features by an inverse document
 * frequency (IDF) learned from training instances.
 *
 * <p>Until the extractor has been trained (via {@link #train(Iterable)}) or loaded
 * (via {@link #load(URI)}), {@link #extract(JCas, Annotation)} wraps the features of
 * the sub-extractor in a single {@link TransformableFeature}. Afterwards it returns
 * one TF-IDF feature per sub-extractor feature.
 *
 * @param <OUTCOME_T> outcome type of the instances used for training
 */
public class TfidfExtractor<OUTCOME_T>
extends OneToOneTrainableExtractor_ImplBase<OUTCOME_T>
implements SimpleFeatureExtractor {
    /** Extractor producing the raw term-frequency features; may be {@code null}. */
    protected SimpleFeatureExtractor subExtractor;
    /** Whether document frequencies have been computed by training or read by loading. */
    protected boolean isTrained;
    /** Document-frequency statistics used to compute IDF weights. */
    protected IDFMap idfMap;

    /**
     * Creates an extractor without a sub-extractor.
     *
     * <p>{@link #extract(JCas, Annotation)} dereferences the sub-extractor, so an
     * instance created this way cannot be used for extraction.
     *
     * @param name name passed to the base class and used for the untransformed feature
     */
    public TfidfExtractor(String name) {
        this(name, null);
    }

    /**
     * Creates an untrained extractor with empty document-frequency statistics.
     *
     * @param name      name passed to the base class and used for the untransformed feature
     * @param extractor extractor whose features are weighted by IDF
     */
    public TfidfExtractor(String name, SimpleFeatureExtractor extractor) {
        super(name);
        this.subExtractor = extractor;
        this.isTrained = false;
        this.idfMap = new IDFMap();
    }

    /**
     * Converts a term-frequency feature into a TF-IDF feature.
     *
     * @param feature feature whose value is a numeric term frequency
     * @return a feature named {@code "TF-IDF_" + feature.getName()} whose value is the
     *         term frequency multiplied by the IDF of the feature name
     */
    @Override
    protected Feature transform(Feature feature) {
        // Accept any numeric count, not only Integer; the result for Integer is unchanged.
        double tf = ((Number)feature.getValue()).doubleValue();
        double tfidf = tf * this.idfMap.getIDF(feature.getName());
        return new Feature("TF-IDF_" + feature.getName(), tfidf);
    }

    /**
     * Extracts features for the given annotation.
     *
     * @param view            the JCas passed on to the sub-extractor
     * @param focusAnnotation the annotation passed on to the sub-extractor
     * @return when trained, one transformed feature per sub-extractor feature;
     *         otherwise a single {@link TransformableFeature} holding the
     *         sub-extractor's features
     * @throws CleartkExtractorException if the sub-extractor fails
     */
    @Override
    public List<Feature> extract(JCas view, Annotation focusAnnotation) throws CleartkExtractorException {
        List<Feature> extracted = this.subExtractor.extract(view, focusAnnotation);
        List<Feature> result = new ArrayList<Feature>();
        if (this.isTrained) {
            for (Feature feature : extracted) {
                result.add(this.transform(feature));
            }
        } else {
            // Defer the transformation: keep the raw features together under this extractor's name.
            result.add(new TransformableFeature(this.name, extracted));
        }
        return result;
    }

    /**
     * Builds document-frequency statistics, treating every instance as one document.
     *
     * <p>Within one instance, each distinct feature name found inside the features
     * accepted by the inherited {@code isTransformable} check is counted once.
     *
     * @param instances the training instances
     * @return a new map holding the document frequencies and the number of instances
     */
    protected IDFMap createIdfMap(Iterable<Instance<OUTCOME_T>> instances) {
        IDFMap newIdfMap = new IDFMap();
        for (Instance<OUTCOME_T> instance : instances) {
            // A set, so a term occurring several times in one instance counts as one document.
            HashSet<String> featureNames = new HashSet<String>();
            for (Feature feature : instance.getFeatures()) {
                if (!this.isTransformable(feature)) {
                    continue;
                }
                for (Feature untransformedFeature : ((TransformableFeature)feature).getFeatures()) {
                    featureNames.add(untransformedFeature.getName());
                }
            }
            for (String featureName : featureNames) {
                newIdfMap.add(featureName);
            }
            newIdfMap.incTotalDocumentCount();
        }
        return newIdfMap;
    }

    /**
     * Computes the document frequencies from the given instances and marks this
     * extractor as trained. Any previously held statistics are replaced.
     *
     * @param instances the training instances
     */
    @Override
    public void train(Iterable<Instance<OUTCOME_T>> instances) {
        this.idfMap = this.createIdfMap(instances);
        this.isTrained = true;
    }

    /**
     * Writes the document-frequency statistics to a file.
     *
     * @param documentFreqDataURI location of the file to write
     * @throws IOException if the file cannot be written
     */
    @Override
    public void save(URI documentFreqDataURI) throws IOException {
        this.idfMap.save(documentFreqDataURI);
    }

    /**
     * Reads the document-frequency statistics from a file and marks this extractor
     * as trained.
     *
     * @param documentFreqDataURI location of the file to read
     * @throws IOException if the file cannot be read or is malformed
     */
    @Override
    public void load(URI documentFreqDataURI) throws IOException {
        this.idfMap.load(documentFreqDataURI);
        this.isTrained = true;
    }

    /**
     * Document-frequency table together with the total number of documents seen.
     */
    protected static class IDFMap {
        private Multiset<String> documentFreqMap = LinkedHashMultiset.create();
        private int totalDocumentCount = 0;

        /**
         * Records that one more document contains the given term.
         *
         * @param term the term to count
         */
        public void add(String term) {
            this.documentFreqMap.add(term);
        }

        /** Increments the total number of documents by one. */
        public void incTotalDocumentCount() {
            ++this.totalDocumentCount;
        }

        /**
         * @return the total number of documents counted or loaded
         */
        public int getTotalDocumentCount() {
            return this.totalDocumentCount;
        }

        /**
         * @param term the term to look up
         * @return the number of documents recorded for the term, 0 if it is unknown
         */
        public int getDF(String term) {
            return this.documentFreqMap.count(term);
        }

        /**
         * Computes the smoothed inverse document frequency
         * {@code log((totalDocumentCount + 1) / (df + 1))}.
         *
         * @param term the term to look up
         * @return the IDF of the term
         */
        public double getIDF(String term) {
            int df = this.getDF(term);
            // Divide in floating point: int / int would truncate the ratio before the log
            // (for example 11 / 4 would become 2 instead of 2.75).
            return Math.log((this.totalDocumentCount + 1) / (double)(df + 1));
        }

        /**
         * Writes the statistics as tab-separated text: a {@code NUM DOCUMENTS} header
         * line followed by one {@code term<TAB>count} line per term.
         *
         * @param outputURI location of the file to write
         * @throws IOException if the file cannot be written
         */
        public void save(URI outputURI) throws IOException {
            BufferedWriter writer = new BufferedWriter(new FileWriter(new File(outputURI)));
            try {
                writer.append(String.format(Locale.ROOT, "NUM DOCUMENTS\t%d\n", this.totalDocumentCount));
                for (Multiset.Entry<String> entry : this.documentFreqMap.entrySet()) {
                    writer.append(String.format(Locale.ROOT, "%s\t%d\n", entry.getElement(), entry.getCount()));
                }
            } finally {
                // Release the file handle even when writing fails.
                writer.close();
            }
        }

        /**
         * Replaces the statistics with the content of a file written by {@link #save(URI)}.
         *
         * <p>The current state is only replaced after the whole file was parsed
         * successfully, so a failed load leaves this map unchanged.
         *
         * @param inputURI location of the file to read
         * @throws IOException if the file cannot be read, is empty, or contains a malformed line
         */
        public void load(URI inputURI) throws IOException {
            BufferedReader reader = new BufferedReader(new FileReader(new File(inputURI)));
            try {
                String firstLine = reader.readLine();
                if (firstLine == null) {
                    throw new IOException("Empty document frequency file: " + inputURI);
                }
                String[] keyValuePair = firstLine.split("\\t");
                if (keyValuePair.length < 2) {
                    throw new IOException("Malformed header line in " + inputURI + ": " + firstLine);
                }
                int loadedTotal = IDFMap.parseCount(keyValuePair[1], firstLine, inputURI);

                // Parse into a fresh multiset so counts are not added on top of existing ones.
                Multiset<String> loadedFreqs = LinkedHashMultiset.create();
                String line;
                while ((line = reader.readLine()) != null) {
                    String[] termFreqPair = line.split("\\t");
                    if (termFreqPair.length < 2) {
                        throw new IOException("Malformed line in " + inputURI + ": " + line);
                    }
                    loadedFreqs.add(termFreqPair[0], IDFMap.parseCount(termFreqPair[1], line, inputURI));
                }

                this.totalDocumentCount = loadedTotal;
                this.documentFreqMap = loadedFreqs;
            } finally {
                // Release the file handle even when reading or parsing fails.
                reader.close();
            }
        }

        /** Parses an integer count, reporting a bad value as an {@link IOException}. */
        private static int parseCount(String text, String line, URI source) throws IOException {
            try {
                return Integer.parseInt(text);
            } catch (NumberFormatException e) {
                throw new IOException("Invalid count in " + source + ": " + line, e);
            }
        }
    }
}

