/*
 * SOMClassifier.java
 *
 * Created on 2007-06-05, 18:44
 *
 */

package iaik.som;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/**
 *
 * @author Gernot WALZL
 */
public class SOMClassifier extends SelfOrganizingMap {

    protected int numClasses;
    protected int[][] myHits;
    protected double[][] normalizedHits;
    protected Object[] labels;

    /** Creates a new instance of SOMClassifier */
    public SOMClassifier(double[][] trainingData, int[] classes) {
        super(trainingData);

        int length = classes.length;  // = trainingData.length
        int dim = trainingData[0].length;
        int codelength = SOMData.codebook.length;
        numClasses = MyMath.max(classes) + 1;
        labels = new Object[numClasses];
        myHits = new int[numClasses][codelength];
        LinkedList<double[]> dataList = new LinkedList<double[]>();
        double[][] sData;
        for (int i = 0; i < numClasses; i++) {
            dataList.clear();
            for (int j = 0; j < length; j++) {
                if (i == classes[j]) {
                    dataList.add(trainingData[j]);
                }
            }
            sData = new double[dataList.size()][dim];
            Iterator<double[]> it = dataList.iterator();
            for (int j = 0; j < dataList.size(); j++) {
                sData[j] = it.next();
            }
            myHits[i] = hits(sData);
        }
        normalizeHits();
    }


    public SOMClassifier(String content) {
        super(content);

        String[] splitted = content.split("som_hits: ");
        String mycontent = splitted[1];

        String[] line = mycontent.split(NEWLINE);
        String[] word;

        numClasses = Integer.valueOf(line[0]).intValue();
        int codelength = line[1].split(" ").length;
        myHits = new int[numClasses][codelength];

        for (int i = 0; i < numClasses; i++) {
            word = line[i+1].split(" ");
            for (int j = 0; j < codelength; j++)
                myHits[i][j] = Integer.valueOf(word[j]).intValue();
        }
        normalizeHits();
    }


    protected void normalizeHits() {
        int codelength = myHits[0].length;
        int sum = 0;

        normalizedHits = new double[numClasses][codelength];
        for (int i = 0; i < numClasses; i++) {
            sum = MyMath.sum(myHits[i]);
            for (int j = 0; j < codelength; j++) {
                normalizedHits[i][j] = (double)myHits[i][j] / sum;
            }
        }
    }


    public double[][] classificationProbability(double[][] testData) {
        int length = testData.length;
        int numclasses = myHits.length;

        double[][] classificationProb = new double[length][numclasses];

        int[] bmus = bmus(testData);
        for (int i = 0; i < length; i++) {
            for (int j = 0; j < numclasses; j++) {
                classificationProb[i][j] = normalizedHits[j][bmus[i]];
            }
        }

        // normalize classification
        double sum = 0;
        for (int i = 0; i < length; i++) {
            sum = MyMath.sum(classificationProb[i]);
            for (int j = 0; j < numclasses; j++) {
                classificationProb[i][j] = classificationProb[i][j] / sum;
            }
        }

        return classificationProb;
    }


    public int[] classify(double[][] testData) {
        int length = testData.length;
        double[][] classificationProb = classificationProbability(testData);
        int[] classification = new int[length];

        for (int i = 0; i < length; i++) {
            classification[i] = MyMath.maxIndex(classificationProb[i]);
        }

        return classification;
    }


    public void setLabels(Object[] labels) {
        this.labels = labels;
    }


    public Object[] classifyLabeled(double[][] testData) {
        int length = testData.length;
        Object[] result = new Object[length];
        int[] classified = classify(testData);
        for (int i = 0; i < length; i++) {
            result[i] = labels[classified[i]];
        }
        return result;
    }


    public int[][] countClassifications(double[][] testData, int[] classes) {
        int length = testData.length;
        int numTestClasses = MyMath.max(classes) + 1;
        int[] classification = classify(testData);
        int[][] classificationCount = new int[numTestClasses][numClasses];
        for (int i = 0; i < numTestClasses; i++) {
            for (int j = 0; j < numClasses; j++) {
                classificationCount[i][j] = 0;
            }
        }
        for (int i = 0; i < length; i++) {
            classificationCount[classes[i]][classification[i]]++;
        }

        return classificationCount;
    }


    public double[][] calculateRelativeError(double[][] testData,
            int[] classes) {
        int[][] classcount = countClassifications(testData, classes);
        int length = classcount.length;
        double[][] relativeError = new double[length][length];
        int sum;
        for (int i = 0; i < length; i++) {
            sum = 0;
            for (int j = 0; j < length; j++) {
                sum += classcount[i][j];
            }
            for (int j = 0; j < length; j++) {
                if (i == j) {
                    relativeError[i][j] = 0.0;
                } else {
                    relativeError[i][j] = (double)classcount[i][j] / sum;
                }
            }
        }
        return relativeError;
    }


    public String toString() {
        StringBuffer content = new StringBuffer();
        content.append(super.toString());
        content.append("som_hits: " + numClasses + NEWLINE);
        String line;
        for (int i = 0; i < myHits.length; i++) {
            line = "";
            for (int j = 0; j < myHits[0].length; j++) {
                if (!line.equals(""))
                    line += " ";
                line += myHits[i][j];
            }
            content.append(line + NEWLINE);
        }
        return content.toString();
    }

}
