package sfa.classification;

import java.io.File;
import java.io.IOException;
import java.util.Iterator;
import java.util.List;
import junit.framework.TestCase;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import sfa.SFAWordsTest;
import sfa.classification.Classifier;
import sfa.timeseries.TimeSeries;
import sfa.timeseries.TimeSeriesLoader;

/**
 * Base class for classifier regression tests.
 *
 * <p>Subclasses supply the classifier under test via {@link #initClassifier()} and the list of
 * data sets with their expected accuracies via {@link #getDataSets()}. For every data set the
 * classifier is evaluated on each TRAIN/TEST file pair found below {@link #DATASETS_DIRECTORY}
 * and the reported accuracies are compared with the expected values using a tolerance of
 * {@link #DELTA}.
 */
@RunWith(JUnit4.class)
public abstract class AbstractClassifierTest {
    /** Absolute tolerance used when comparing expected and measured accuracy/earliness values. */
    private static final double DELTA = 0.05d;

    /** Directory resolved from the classpath resource {@code datasets/univariate/}. */
    protected static final File DATASETS_DIRECTORY = new File(AbstractClassifierTest.class.getClassLoader().getResource("datasets/univariate/").getFile());

    /** Marker that identifies a training file; it is swapped for {@link #TEST_MARKER} to find the test file. */
    private static final String TRAIN_MARKER = "TRAIN";

    /** Marker that replaces {@link #TRAIN_MARKER} when deriving the test file from a training file. */
    private static final String TEST_MARKER = "TEST";

    /**
     * Expected results for one data set: its directory name below {@link #DATASETS_DIRECTORY}
     * and the accuracies (and optionally the test earliness) the classifier should reach.
     */
    public static final class DataSet {
        String name;
        double trainingAccuracy;
        double testingAccuracy;
        double testEarliness;

        /**
         * Creates an expectation without an explicit test earliness (it stays at {@code 0.0}).
         *
         * @param name             name of the data set directory
         * @param trainingAccuracy expected training accuracy
         * @param testingAccuracy  expected testing accuracy
         */
        public DataSet(String name, double trainingAccuracy, double testingAccuracy) {
            this.name = name;
            this.trainingAccuracy = trainingAccuracy;
            this.testingAccuracy = testingAccuracy;
        }

        /**
         * Creates an expectation that also carries the expected test earliness.
         *
         * @param name             name of the data set directory
         * @param trainingAccuracy expected training accuracy
         * @param testingAccuracy  expected testing accuracy
         * @param testEarliness    expected test earliness
         */
        public DataSet(String name, double trainingAccuracy, double testingAccuracy, double testEarliness) {
            this(name, trainingAccuracy, testingAccuracy);
            this.testEarliness = testEarliness;
        }
    }

    /** Trains and evaluates the classifier on every configured data set. */
    @Test
    public void testClassificationOnUCRData() {
        for (DataSet dataSet : getDataSets()) {
            // trainClassifier returns null when no training file was found for the data set.
            Assert.assertNotNull(trainClassifier(dataSet));
        }
    }

    /**
     * Checks that a classifier written to disk and loaded again scores like the original one.
     * Only the first configured data set is used.
     *
     * @throws IOException if the temporary file cannot be created or the classifier cannot be
     *                     saved or loaded
     */
    @Test
    public void testSave() throws IOException {
        List<DataSet> dataSets = getDataSets();
        Assert.assertFalse("getDataSets() must provide at least one data set", dataSets.isEmpty());
        testSaveLoadGivesEqualTestResults(dataSets.get(0));
    }

    /** Trains on {@code dataSet}, round-trips the classifier through a temp file and compares both. */
    private void testSaveLoadGivesEqualTestResults(DataSet dataSet) throws IOException {
        Classifier trained = trainClassifier(dataSet);
        File classifierFile = createTempClassifierFile();
        trained.save(classifierFile);

        Classifier reloaded = Classifier.load(classifierFile);
        Assert.assertNotNull(reloaded);
        checkEqualResultsOfClassifiers(dataSet, trained, reloaded);
    }

    /**
     * Scores the first training file of {@code dataSet} with both classifiers and asserts that
     * the predicted labels and the number of correct predictions are identical.
     *
     * @param classifier  the originally trained classifier
     * @param classifier2 the classifier obtained by loading the saved copy
     */
    private void checkEqualResultsOfClassifiers(DataSet dataSet, Classifier classifier, Classifier classifier2) {
        TimeSeries[] samples = TimeSeriesLoader.loadDataset(getFirstTrainFile(dataSet));
        Assert.assertArrayEquals(classifier2.score(samples).labels, classifier.score(samples).labels);
        // Compare the loaded classifier against the original one, not a value with itself.
        Assert.assertEquals(classifier2.score(samples).correct.get(), classifier.score(samples).correct.get());
    }

    /** Returns the first training file of {@code dataSet}, failing clearly if there is none. */
    private File getFirstTrainFile(DataSet dataSet) {
        File[] trainFiles = getTrainFiles(dataSet);
        Assert.assertTrue("No " + TRAIN_MARKER + " file found for data set " + dataSet.name,
                trainFiles != null && trainFiles.length > 0);
        return trainFiles[0];
    }

    /** Creates a temporary file for a serialized classifier; it is deleted when the JVM exits. */
    private File createTempClassifierFile() throws IOException {
        File tempFile = File.createTempFile("classifier", "class");
        tempFile.deleteOnExit();
        return tempFile;
    }

    /**
     * Evaluates a fresh classifier on every TRAIN/TEST file pair of {@code dataSet} and asserts
     * that the measured accuracies (and the test earliness, when the score reports one) match the
     * expected values within {@link #DELTA}.
     *
     * @param dataSet the data set and its expected results
     * @return the classifier evaluated on the last training file, or {@code null} if the data set
     *         has no training files
     */
    protected Classifier trainClassifier(DataSet dataSet) {
        Classifier classifier = null;
        for (File trainFile : getTrainFiles(dataSet)) {
            File testFile = testFileFor(trainFile);
            if (!testFile.exists()) {
                System.err.println("File " + testFile.getName() + " does not exist");
                // NOTE(review): a missing test file is passed on as null; how the loader treats
                // null is not visible here - confirm against TimeSeriesLoader.loadDataset.
                testFile = null;
            }
            Classifier.DEBUG = false;

            TimeSeries[] testSamples = TimeSeriesLoader.loadDataset(testFile);
            TimeSeries[] trainSamples = TimeSeriesLoader.loadDataset(trainFile);

            classifier = initClassifier();
            Classifier.Score score = classifier.eval(trainSamples, testSamples);
            System.out.println(score.toString());

            Assert.assertEquals("testing result of " + dataSet.name + " does NOT match", dataSet.testingAccuracy, score.getTestingAccuracy(), DELTA);
            Assert.assertEquals("training result of " + dataSet.name + " does NOT match", dataSet.trainingAccuracy, score.getTrainingAccuracy(), DELTA);
            if (score.getTestEarliness() != null) {
                Assert.assertEquals("test earliness result of " + dataSet.name + " does NOT match", dataSet.testEarliness, score.getTestEarliness().doubleValue(), DELTA);
            }
        }
        return classifier;
    }

    /**
     * Derives the test file that belongs to {@code trainFile} by replacing the first
     * {@code "TRAIN"} with {@code "TEST"}.
     *
     * <p>The replacement is applied to the file name so that a parent directory whose name
     * happens to contain {@code "TRAIN"} is not rewritten by accident. If the file name itself
     * does not contain the marker, the whole absolute path is rewritten instead.
     */
    private static File testFileFor(File trainFile) {
        String fileName = trainFile.getName();
        if (fileName.contains(TRAIN_MARKER)) {
            File directory = trainFile.getAbsoluteFile().getParentFile();
            return new File(directory, fileName.replaceFirst(TRAIN_MARKER, TEST_MARKER));
        }
        return new File(trainFile.getAbsolutePath().replaceFirst(TRAIN_MARKER, TEST_MARKER));
    }

    /**
     * Returns the training files of {@code dataSet}, i.e. the files in the directory
     * {@code DATASETS_DIRECTORY/<name>} whose name ends with {@code "TRAIN"} (ignoring case).
     */
    protected File[] getTrainFiles(DataSet dataSet) {
        return getTrainFilesFromDir(new File(DATASETS_DIRECTORY.getAbsolutePath() + "/" + dataSet.name));
    }

    /** Lists the training files in {@code file}, failing clearly if the directory cannot be listed. */
    private File[] getTrainFilesFromDir(File file) {
        File[] trainFiles = file.listFiles(AbstractClassifierTest::isTrainFile);
        // File.listFiles yields null when the path is not a directory or an I/O error occurs.
        Assert.assertNotNull("Cannot list data set directory " + file.getAbsolutePath(), trainFiles);
        return trainFiles;
    }

    /** Tells whether the name of {@code candidate} ends with {@code "TRAIN"}, ignoring case. */
    private static boolean isTrainFile(File candidate) {
        String fileName = candidate.getName();
        int suffixLength = TRAIN_MARKER.length();
        // regionMatches(ignoreCase = true, ...) does not depend on the default locale, unlike
        // toUpperCase(); a negative offset (name shorter than the marker) simply yields false.
        return fileName.regionMatches(true, fileName.length() - suffixLength, TRAIN_MARKER, 0, suffixLength);
    }

    /**
     * Supplies the data sets to test together with their expected results.
     *
     * @return the data sets; {@link #testSave()} requires at least one element
     */
    protected abstract List<DataSet> getDataSets();

    /**
     * Creates a new, untrained instance of the classifier under test.
     *
     * @return the classifier to evaluate
     */
    protected abstract Classifier initClassifier();
}
