package adams.data.instancesanalysis.pls;

import adams.core.base.BaseRegExp;
import gnu.trove.list.TIntList;
import gnu.trove.list.array.TIntArrayList;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import weka.core.Attribute;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Center;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

/* loaded from: input_file:adams/data/instancesanalysis/pls/AbstractMultiClassPLS.class */
/**
 * Base class for PLS implementations that work with more than one class
 * attribute.
 * <p>
 * The class attributes are the attribute flagged as class in the dataset plus
 * every attribute whose name matches the regular expression returned by
 * {@link #getClassAttributes()}.
 */
public abstract class AbstractMultiClassPLS extends AbstractPLS {
    private static final long serialVersionUID = 5649007256147616278L;

    /** Key under which the original class values are stored in the parameter map. */
    public static final String PARAM_CLASSVALUES = "classValues";

    /** Regular expression identifying additional class attributes by name. */
    protected BaseRegExp m_ClassAttributes = getDefaultClassAttributes();

    /** Filter for replacing missing values, {@code null} if not used. */
    protected Filter m_Missing;

    /** Preprocessing filter (centering/standardizing), {@code null} if not used. */
    protected Filter m_Filter;

    /** Indices (in the input data) of the class attributes. */
    protected TIntList m_ClassAttributeIndices;

    /** Mean per class attribute, keyed by the attribute's index in the input data. */
    protected Map<Integer, Double> m_ClassMean;

    /** Standard deviation per class attribute, keyed by the attribute's index in the input data. */
    protected Map<Integer, Double> m_ClassStdDev;

    /**
     * Resets the scheme and discards the filters, class attribute indices and
     * class statistics held by this class.
     */
    @Override
    public void reset() {
        super.reset();
        this.m_Missing = null;
        this.m_Filter = null;
        this.m_ClassAttributeIndices = null;
        // the statistics belong to the discarded filters/indices, drop them as well
        this.m_ClassMean = null;
        this.m_ClassStdDev = null;
    }

    /**
     * Adds the "class-attributes" option to the options of the superclass.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        // use the same default as the field initializer
        this.m_OptionManager.add("class-attributes", "classAttributes", getDefaultClassAttributes());
    }

    /**
     * Returns the default regular expression for the class attributes.
     *
     * @return the default, an empty expression
     */
    protected BaseRegExp getDefaultClassAttributes() {
        return new BaseRegExp("");
    }

    /**
     * Sets the regular expression for identifying the class attributes and
     * resets the scheme.
     *
     * @param baseRegExp the regular expression
     */
    public void setClassAttributes(BaseRegExp baseRegExp) {
        this.m_ClassAttributes = baseRegExp;
        reset();
    }

    /**
     * Returns the regular expression for identifying the class attributes.
     *
     * @return the regular expression
     */
    public BaseRegExp getClassAttributes() {
        return this.m_ClassAttributes;
    }

    /**
     * Returns the tip text for the class attributes option.
     *
     * @return the tip text
     */
    public String classAttributesTipText() {
        return "The regular expression for identifying the class attributes (besides an explicitly set one).";
    }

    /**
     * Determines the output format: one attribute per component, followed by
     * the class attributes. As a side effect, the indices of the class
     * attributes in the input data are recorded; the n-th recorded index
     * corresponds to the n-th class attribute of the output format.
     *
     * @param instances the input data, must have a class attribute set
     * @return the (empty) output format, the last attribute being the class
     * @throws Exception if determining the format fails
     */
    @Override
    public Instances determineOutputFormat(Instances instances) throws Exception {
        this.m_ClassAttributeIndices = new TIntArrayList();
        ArrayList<String> classNames = new ArrayList<>();
        for (int i = 0; i < instances.numAttributes(); i++) {
            String name = instances.attribute(i).name();
            if (this.m_ClassAttributes.isMatch(name)) {
                classNames.add(name);
                this.m_ClassAttributeIndices.add(i);
            }
        }

        // the explicitly set class attribute is always included
        Attribute classAttribute = instances.classAttribute();
        if (!classNames.contains(classAttribute.name())) {
            classNames.add(classAttribute.name());
            this.m_ClassAttributeIndices.add(classAttribute.index());
        }

        String prefix = getClass().getSimpleName();
        ArrayList<Attribute> attributes = new ArrayList<>();
        for (int i = 0; i < getNumComponents(); i++) {
            attributes.add(new Attribute(prefix + "_" + (i + 1)));
        }
        for (String className : classNames) {
            attributes.add(new Attribute(className));
        }

        Instances format = new Instances(prefix, attributes, 0);
        format.setClassIndex(format.numAttributes() - 1);
        this.m_OutputFormat = format;
        return format;
    }

    /**
     * Preprocesses the data before the actual transformation.
     * <p>
     * For every prediction type other than ALL, the current values of all
     * class attributes are stored in the parameter map under
     * {@link #PARAM_CLASSVALUES}. If the scheme is not yet initialized, the
     * missing-values filter, the preprocessing filter and the per-class
     * statistics are set up from the supplied data. Finally, the configured
     * filters are applied.
     *
     * @param instances the data to preprocess
     * @param map the parameter map shared with {@code postTransform}
     * @return the preprocessed data
     * @throws Exception if filtering fails
     * @throws IllegalStateException if the preprocessing type is not handled
     */
    @Override
    protected Instances preTransform(Instances instances, Map<String, Object> map) throws Exception {
        int numClasses = this.m_ClassAttributeIndices.size();

        switch (this.m_PredictionType) {
            case ALL:
                break;
            default: {
                Map<Integer, double[]> classValues = new HashMap<>();
                for (int n = 0; n < numClasses; n++) {
                    int index = this.m_ClassAttributeIndices.get(n);
                    classValues.put(index, instances.attributeToDoubleArray(index));
                }
                map.put(PARAM_CLASSVALUES, classValues);
                break;
            }
        }

        if (!isInitialized()) {
            if (this.m_ReplaceMissing) {
                this.m_Missing = new ReplaceMissingValues();
                this.m_Missing.setInputFormat(instances);
            } else {
                this.m_Missing = null;
            }

            // the filter does not depend on the class attribute, build it once
            this.m_Filter = createPreprocessingFilter();

            this.m_ClassMean = new HashMap<>();
            this.m_ClassStdDev = new HashMap<>();
            for (int n = 0; n < numClasses; n++) {
                int index = this.m_ClassAttributeIndices.get(n);
                double mean;
                double stdDev;
                switch (this.m_PreprocessingType) {
                    case CENTER:
                        mean = instances.meanOrMode(index);
                        stdDev = 1.0d;
                        break;
                    case STANDARDIZE:
                        mean = instances.meanOrMode(index);
                        stdDev = StrictMath.sqrt(instances.variance(index));
                        break;
                    default:
                        // NONE; unhandled types were rejected by createPreprocessingFilter()
                        mean = 0.0d;
                        stdDev = 1.0d;
                        break;
                }
                this.m_ClassMean.put(index, mean);
                this.m_ClassStdDev.put(index, stdDev);
            }

            if (this.m_Filter != null) {
                this.m_Filter.setInputFormat(instances);
            }
        }

        if (this.m_Missing != null) {
            instances = Filter.useFilter(instances, this.m_Missing);
        }
        if (this.m_Filter != null) {
            instances = Filter.useFilter(instances, this.m_Filter);
        }
        return instances;
    }

    /**
     * Creates the filter matching the current preprocessing type.
     *
     * @return the filter (configured to ignore the class), or {@code null} for type NONE
     * @throws IllegalStateException if the preprocessing type is not handled
     */
    private Filter createPreprocessingFilter() {
        switch (this.m_PreprocessingType) {
            case CENTER: {
                // setIgnoreClass is not declared by Filter, hence the typed local
                Center center = new Center();
                center.setIgnoreClass(true);
                return center;
            }
            case STANDARDIZE: {
                Standardize standardize = new Standardize();
                standardize.setIgnoreClass(true);
                return standardize;
            }
            case NONE:
                return null;
            default:
                throw new IllegalStateException("Unhandled preprocessing type; " + this.m_PreprocessingType);
        }
    }

    /**
     * Postprocesses the transformed data.
     * <p>
     * If original class values were stored by {@code preTransform}, they are
     * written back into the class attributes. Otherwise the values of each
     * class attribute are rescaled using that attribute's own standard
     * deviation and mean ({@code value * stdDev + mean}).
     * <p>
     * Each class attribute is updated in its own column: the trailing
     * attributes of the data are taken to be the class attributes, in the
     * order built by {@link #determineOutputFormat(Instances)}.
     * NOTE(review): this assumes the data handed in here uses the output
     * format from determineOutputFormat - confirm for all subclasses.
     *
     * @param instances the transformed data, updated in place
     * @param map the parameter map filled by {@code preTransform}
     * @return the updated data
     * @throws Exception if postprocessing fails
     * @throws IllegalStateException if the data has fewer attributes than there are class attributes
     */
    @Override
    protected Instances postTransform(Instances instances, Map<String, Object> map) throws Exception {
        @SuppressWarnings("unchecked") // only preTransform stores this entry, with exactly this type
        Map<Integer, double[]> classValues = (Map<Integer, double[]>) map.get(PARAM_CLASSVALUES);

        int numClasses = this.m_ClassAttributeIndices.size();
        int firstClass = instances.numAttributes() - numClasses;
        if (firstClass < 0) {
            throw new IllegalStateException(
                "Transformed data has " + instances.numAttributes() + " attribute(s), but "
                    + numClasses + " class attribute(s) are expected");
        }

        int numRows = instances.numInstances();
        for (int n = 0; n < numClasses; n++) {
            int inputIndex = this.m_ClassAttributeIndices.get(n);
            int outputIndex = firstClass + n;
            if (classValues != null) {
                double[] original = classValues.get(inputIndex);
                for (int row = 0; row < numRows; row++) {
                    instances.instance(row).setValue(outputIndex, original[row]);
                }
            } else {
                double stdDev = this.m_ClassStdDev.get(inputIndex);
                double mean = this.m_ClassMean.get(inputIndex);
                for (int row = 0; row < numRows; row++) {
                    double scaled = instances.instance(row).value(outputIndex);
                    instances.instance(row).setValue(outputIndex, scaled * stdDev + mean);
                }
            }
        }
        return instances;
    }
}
