/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.parameterserver.distributed.training;

import java.lang.reflect.Modifier;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import lombok.NonNull;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.parameterserver.distributed.conf.VoidConfiguration;
import org.nd4j.parameterserver.distributed.logic.Storage;
import org.nd4j.parameterserver.distributed.logic.completion.Clipboard;
import org.nd4j.parameterserver.distributed.messages.TrainingMessage;
import org.nd4j.parameterserver.distributed.training.TrainingDriver;
import org.nd4j.parameterserver.distributed.transport.Transport;
import org.reflections.Reflections;
import org.reflections.scanners.Scanner;

/**
 * Process-wide registry that maps {@link TrainingMessage} types to the {@link TrainingDriver}
 * that executes them.
 *
 * <p>The single instance is created eagerly when this class is initialized. During construction
 * every concrete {@link TrainingDriver} subtype found by {@link Reflections} under the
 * {@code org} package prefix is instantiated through its no-arg constructor. Each driver is
 * registered under the key returned by {@link TrainingDriver#targetMessageClass()}. Messages are
 * later looked up by {@code message.getClass().getSimpleName()}, so a driver's
 * {@code targetMessageClass()} must return exactly that simple name to be found.
 *
 * <p>NOTE(review): nothing here is synchronized. {@link #init} presumably has to run once,
 * before any {@link #doTraining} call; confirm this against callers.
 */
public class TrainerProvider {
    private static final TrainerProvider INSTANCE = new TrainerProvider();

    /** Registered drivers, keyed by the value of their {@code targetMessageClass()}. */
    protected Map<String, TrainingDriver<?>> trainers = new HashMap<>();
    protected VoidConfiguration voidConfiguration;
    protected Transport transport;
    protected Clipboard clipboard;
    protected Storage storage;

    private TrainerProvider() {
        this.scanClasspath();
    }

    /**
     * Returns the shared provider instance.
     *
     * @return the eagerly created singleton
     */
    public static TrainerProvider getInstance() {
        return INSTANCE;
    }

    /**
     * Finds every concrete {@link TrainingDriver} implementation under the {@code org} package
     * prefix, instantiates it with its no-arg constructor, and registers it in {@link #trainers}.
     *
     * @throws RuntimeException if a discovered driver cannot be instantiated reflectively (the
     *     reflective failure is attached as the cause)
     * @throws ND4JIllegalStateException if no driver implementation was found at all
     */
    protected void scanClasspath() {
        Reflections reflections = new Reflections("org");
        Set<Class<? extends TrainingDriver>> classes = reflections.getSubTypesOf(TrainingDriver.class);
        for (Class<? extends TrainingDriver> clazz : classes) {
            // Interfaces and abstract classes cannot be instantiated, so skip them.
            if (clazz.isInterface() || Modifier.isAbstract(clazz.getModifiers())) {
                continue;
            }
            TrainingDriver<?> driver;
            try {
                driver = clazz.getDeclaredConstructor().newInstance();
            } catch (ReflectiveOperationException e) {
                throw new RuntimeException("Unable to instantiate TrainingDriver [" + clazz.getName() + "]", e);
            }
            // NOTE(review): if two drivers report the same targetMessageClass(), the later one
            // silently replaces the earlier, and the iteration order of this Set is unspecified.
            this.trainers.put(driver.targetMessageClass(), driver);
        }
        if (this.trainers.size() < 1) {
            throw new ND4JIllegalStateException("No TrainingDrivers were found");
        }
    }

    /**
     * Stores the shared infrastructure and passes it on to every registered driver.
     *
     * @param voidConfiguration parameter-server configuration; must not be {@code null}
     * @param transport transport layer; must not be {@code null}
     * @param storage shared storage; must not be {@code null}
     * @param clipboard completion clipboard; must not be {@code null}
     * @throws NullPointerException if any argument is {@code null}
     */
    public void init(@NonNull VoidConfiguration voidConfiguration, @NonNull Transport transport, @NonNull Storage storage, @NonNull Clipboard clipboard) {
        if (voidConfiguration == null) {
            throw new NullPointerException("voidConfiguration");
        }
        if (transport == null) {
            throw new NullPointerException("transport");
        }
        if (storage == null) {
            throw new NullPointerException("storage");
        }
        if (clipboard == null) {
            throw new NullPointerException("clipboard");
        }
        this.voidConfiguration = voidConfiguration;
        this.transport = transport;
        this.clipboard = clipboard;
        this.storage = storage;
        for (TrainingDriver<?> trainer : this.trainers.values()) {
            trainer.init(voidConfiguration, transport, storage, clipboard);
        }
    }

    /**
     * Looks up the driver registered for the runtime class of {@code message}.
     *
     * @param message the message to find a driver for
     * @param <T> the message type
     * @return the driver registered under {@code message.getClass().getSimpleName()}
     * @throws ND4JIllegalStateException if no driver is registered for that name
     */
    protected <T extends TrainingMessage> TrainingDriver<T> getTrainer(T message) {
        String key = message.getClass().getSimpleName();
        TrainingDriver<?> driver = this.trainers.get(key);
        if (driver == null) {
            throw new ND4JIllegalStateException("Can't find trainer for [" + key + "]");
        }
        // The map cannot express the per-key type relation, so the cast is unchecked. It assumes
        // each driver's targetMessageClass() names the message type it is parameterized on.
        @SuppressWarnings("unchecked")
        TrainingDriver<T> typed = (TrainingDriver<T>) driver;
        return typed;
    }

    /**
     * Dispatches {@code message} to its registered driver and starts training.
     *
     * @param message the training message to execute
     * @param <T> the message type
     * @throws ND4JIllegalStateException if no driver is registered for the message's class
     */
    public <T extends TrainingMessage> void doTraining(T message) {
        TrainingDriver<T> trainer = this.getTrainer(message);
        trainer.startTraining(message);
    }
}

