/*
 * SOMTreeClassifierNode.java
 *
 * Created on 2007-06-16, 17:08
 *
 */

package iaik.som;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;

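/**
 * Node of a tree of SOM classifiers. Every node owns one
 * {@link SOMClassifier}; groups of classes that this classifier confuses
 * with each other are handed to sub nodes, which are trained on the samples
 * of those classes only.
 *
 * @author Gernot WALZL
 */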
public class SOMTreeClassifierNode {

    /** Classifier trained on the training data handed to this node. */
    protected SOMClassifier classifier;

    /** Sub nodes, one for every group of classes confused by {@link #classifier}. */
    protected ArrayList<SOMTreeClassifierNode> subNodes;

    /** Relative error above which two classes are treated as confused. */
    protected double errorThreshold;

    /** Factor {@link #errorThreshold} is multiplied with when it has to grow. */
    protected double errorThresholdFactorInc;

    /**
     * Creates a new instance of SOMTreeClassifierNode.
     *
     * <p>A classifier is trained on the given data. If {@code MyMath.max(classes)}
     * is greater than 1, the classes are grouped by mutual confusion and one
     * sub node is created per group, using only the samples of that group.
     * The error threshold starts at 0.05 and is doubled as long as a single
     * group would contain every class.
     *
     * <p><b>Side effect:</b> {@code labels} is modified in place; the entry of
     * every class that was moved into a group is replaced by the sub node
     * responsible for that group. NOTE(review): this presumably relies on
     * {@code SOMClassifier.setLabels} keeping a reference to the array so that
     * {@code classifyLabeled} can return the sub nodes - confirm.
     *
     * @param trainingData one feature vector per training sample
     * @param classes class index of every training sample; used as index into
     *        {@code labels} and the merge lists
     * @param labels label per class index; entries are overwritten as
     *        described above
     */
    public SOMTreeClassifierNode(final double[][] trainingData,
            final int[] classes, Object[] labels) {
        classifier = new SOMClassifier(trainingData, classes);
        classifier.setLabels(labels);
        subNodes = new ArrayList<SOMTreeClassifierNode>();
        if (MyMath.max(classes) <= 1) {
            return;
        }

        errorThreshold = 0.05;
        errorThresholdFactorInc = 2.0;

        // Raise the threshold while one group would swallow all classes.
        // A null result means that no pair of classes exceeds the threshold;
        // retrying with an unchanged threshold would repeat the very same
        // computation without end (assuming calculateRelativeError is
        // deterministic), so the search stops and this node stays a leaf.
        boolean[][] mergeLists;
        do {
            final double[][] relativeError =
                    classifier.calculateRelativeError(trainingData, classes);
            final boolean[][] errorMatrix =
                    generateErrorMatrix(relativeError, errorThreshold);
            mergeLists = generateMergeLists(errorMatrix);
        } while (mergeLists != null && incrementErrorThreshold(mergeLists));

        if (mergeLists == null) {
            return;
        }

        for (int i = 0; i < mergeLists.length; i++) {
            final boolean[] mergeList = mergeLists[i];
            final double[][] subTrainingData =
                    filterTrainingData(trainingData, classes, mergeList);
            final int[] subClasses = filterTrainingClasses(classes, mergeList);
            final Object[] subLabels = filterLabels(labels, mergeList);
            final SOMTreeClassifierNode node = new SOMTreeClassifierNode(
                    subTrainingData, subClasses, subLabels);
            subNodes.add(node);
            // Redirect every class of this group to the new sub node.
            for (int j = 0; j < mergeList.length; j++) {
                if (mergeList[j]) {
                    labels[j] = node;
                }
            }
        }
    }


    /**
     * Builds a symmetric matrix telling which pairs of classes are confused.
     *
     * @param relativeError square matrix of relative errors between classes
     * @param errorThreshold limit a relative error has to exceed
     * @return matrix whose entry {@code [i][j]} (and {@code [j][i]}) is
     *         {@code true} if {@code relativeError[i][j]} or
     *         {@code relativeError[j][i]} is greater than the threshold; the
     *         diagonal is always {@code false}
     */
    protected static boolean[][] generateErrorMatrix(
            final double[][] relativeError, final double errorThreshold) {
        final int length = relativeError.length;
        // boolean arrays start out as false, only the upper triangle is
        // evaluated and mirrored.
        final boolean[][] errorMatrix = new boolean[length][length];
        for (int i = 0; i < length; i++) {
            for (int j = i + 1; j < length; j++) {
                final boolean confused = relativeError[i][j] > errorThreshold
                        || relativeError[j][i] > errorThreshold;
                errorMatrix[i][j] = confused;
                errorMatrix[j][i] = confused;
            }
        }
        return errorMatrix;
    }


    /**
     * Expands a merge list until it is closed under the error relation:
     * the row of every class already contained in the list is OR-ed into the
     * list, which may add further classes, until no unprocessed class is
     * left. Implemented as a loop; the name is kept for compatibility.
     *
     * @param mergeList in/out: classes belonging to the group
     * @param errorMatrix matrix as produced by {@link #generateErrorMatrix}
     * @param markedRows in/out: rows that have already been merged
     */
    protected static void generateMergeListRecursive(boolean[] mergeList,
            final boolean[][] errorMatrix, boolean[] markedRows) {
        final int length = mergeList.length;
        boolean expanded;
        do {
            expanded = false;
            for (int r = 0; r < length; r++) {
                if (mergeList[r] && !markedRows[r]) {
                    for (int c = 0; c < length; c++) {
                        mergeList[c] |= errorMatrix[r][c];
                    }
                    markedRows[r] = true;
                    expanded = true;
                }
            }
        } while (expanded);
    }


    /**
     * Groups the classes into sets of mutually (directly or transitively)
     * confused classes.
     *
     * @param errorMatrix matrix as produced by {@link #generateErrorMatrix}
     * @return one boolean mask per group, or {@code null} if the matrix does
     *         not contain a single {@code true} entry
     */
    protected static boolean[][] generateMergeLists(final boolean[][] errorMatrix) {
        final int length = errorMatrix.length;
        final ArrayList<boolean[]> mergeLists = new ArrayList<boolean[]>();
        final boolean[] markedRows = new boolean[length];

        for (int r = 0; r < length; r++) {
            if (markedRows[r] || !rowHasError(errorMatrix[r], length)) {
                continue;
            }
            // Seed the group with the partners of row r; r itself is added
            // by the expansion because the matrix is symmetric.
            final boolean[] mergeList = new boolean[length];
            System.arraycopy(errorMatrix[r], 0, mergeList, 0, length);
            markedRows[r] = true;
            generateMergeListRecursive(mergeList, errorMatrix, markedRows);
            mergeLists.add(mergeList);
        }

        if (mergeLists.isEmpty()) {
            return null;
        }
        return mergeLists.toArray(new boolean[mergeLists.size()][]);
    }

    /** Tells whether one of the first {@code length} entries of the row is set. */
    private static boolean rowHasError(final boolean[] row, final int length) {
        for (int c = 0; c < length; c++) {
            if (row[c]) {
                return true;
            }
        }
        return false;
    }


    /**
     * Decides whether the grouping has to be repeated with a larger error
     * threshold and, if a group contains every class, multiplies
     * {@link #errorThreshold} by {@link #errorThresholdFactorInc}.
     *
     * @param mergeLists groups as returned by {@link #generateMergeLists},
     *        may be {@code null}
     * @return {@code true} if {@code mergeLists} is {@code null} (threshold
     *         left unchanged) or if one group covers all classes (threshold
     *         increased); {@code false} otherwise
     */
    protected boolean incrementErrorThreshold(final boolean[][] mergeLists) {
        if (mergeLists == null) {
            return true;
        }
        if (mergeLists.length == 0) {
            // No group at all, hence none that covers every class.
            return false;
        }

        final int length = mergeLists[0].length;
        for (int r = 0; r < mergeLists.length; r++) {
            if (coversAllClasses(mergeLists[r], length)) {
                errorThreshold *= errorThresholdFactorInc;
                return true;
            }
        }
        return false;
    }

    /** Tells whether all of the first {@code length} entries of the mask are set. */
    private static boolean coversAllClasses(final boolean[] mergeList,
            final int length) {
        for (int c = 0; c < length; c++) {
            if (!mergeList[c]) {
                return false;
            }
        }
        return true;
    }


    /** Counts the samples whose class is selected by the merge list. */
    private static int countSelected(final int[] trainingClasses,
            final boolean[] mergeList) {
        int count = 0;
        for (int i = 0; i < trainingClasses.length; i++) {
            if (mergeList[trainingClasses[i]]) {
                count++;
            }
        }
        return count;
    }

    /**
     * Selects the training samples whose class belongs to the given group.
     *
     * @param trainingData one feature vector per sample
     * @param trainingClasses class index per sample
     * @param mergeList mask over the class indices
     * @return the selected feature vectors in their original order; the
     *         vectors themselves are shared with {@code trainingData}, not
     *         copied
     */
    protected static double[][] filterTrainingData(final double[][] trainingData,
            final int[] trainingClasses, final boolean[] mergeList) {
        final double[][] outputTrainingData =
                new double[countSelected(trainingClasses, mergeList)][];
        int o = 0;
        for (int i = 0; i < trainingClasses.length; i++) {
            if (mergeList[trainingClasses[i]]) {
                outputTrainingData[o++] = trainingData[i];
            }
        }
        return outputTrainingData;
    }

    /**
     * Selects the class indices of the samples belonging to the given group
     * and renumbers them so that the selected classes become 0, 1, 2, ... in
     * ascending order of their original index.
     *
     * @param trainingClasses class index per sample
     * @param mergeList mask over the class indices
     * @return renumbered class index for every selected sample, in the
     *         original sample order
     */
    protected static int[] filterTrainingClasses(final int[] trainingClasses,
            final boolean[] mergeList) {
        // Map every selected original class index to its new, dense index.
        final int[] classConversion = new int[mergeList.length];
        int classCount = 0;
        for (int i = 0; i < mergeList.length; i++) {
            if (mergeList[i]) {
                classConversion[i] = classCount++;
            }
        }

        final int[] outputTrainingClasses =
                new int[countSelected(trainingClasses, mergeList)];
        int o = 0;
        for (int i = 0; i < trainingClasses.length; i++) {
            final int originalClass = trainingClasses[i];
            if (mergeList[originalClass]) {
                outputTrainingClasses[o++] = classConversion[originalClass];
            }
        }
        return outputTrainingClasses;
    }

    /**
     * Selects the labels of the classes belonging to the given group.
     *
     * @param labels label per class index
     * @param mergeList mask over the class indices
     * @return labels of the selected classes in ascending class order, which
     *         matches the numbering of {@link #filterTrainingClasses}
     */
    protected static Object[] filterLabels(final Object[] labels,
            final boolean[] mergeList) {
        int classCount = 0;
        for (int i = 0; i < mergeList.length; i++) {
            if (mergeList[i]) {
                classCount++;
            }
        }

        final Object[] outputLabels = new Object[classCount];
        int o = 0;
        for (int i = 0; i < mergeList.length; i++) {
            if (mergeList[i]) {
                outputLabels[o++] = labels[i];
            }
        }
        return outputLabels;
    }


    /**
     * Classifies the given samples. A sample for which the classifier of
     * this node yields a sub node is passed on to that sub node; any other
     * label is cast to {@link Integer}, so a label of a different type
     * causes a {@link ClassCastException}.
     *
     * @param testData one feature vector per sample
     * @return the integer label determined for every sample
     */
    public int[] classify(final double[][] testData) {
        final int[] result = new int[testData.length];
        final Object[] classified = classifier.classifyLabeled(testData);
        for (int i = 0; i < result.length; i++) {
            final Object label = classified[i];
            if (label instanceof SOMTreeClassifierNode) {
                final SOMTreeClassifierNode subNode = (SOMTreeClassifierNode) label;
                result[i] = subNode.classify(new double[][] {testData[i]})[0];
            } else {
                result[i] = ((Integer) label).intValue();
            }
        }
        return result;
    }

}
