package ceka.IWBVT;

import weka.classifiers.Classifier;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.matrix.Matrix;

public class LinearRegression extends Classifier{
	//ݳԱ
	private double[] m_Wb;				//
	
	private Instances m_Instances;		//ʵ
	private int m_NumAtts;				//Ը
	private int m_NumInstances;			//ʵ

	//ѵ
	public void buildClassifier(Instances train) throws Exception {
		m_Instances = new Instances(train);
		m_NumAtts = m_Instances.numAttributes();
		m_NumInstances = m_Instances.numInstances();
		m_Wb = new double[m_NumAtts];
		//󸳳ֵ
		Matrix Matrix_X = new Matrix(m_NumInstances, m_NumAtts);
		Matrix Matrix_Y = new Matrix(m_NumInstances,1);
		for(int i=0;i<m_NumInstances;i++) {
			Matrix_Y.set(i, 0, m_Instances.instance(i).classValue());
			for(int j=0;j<m_NumAtts-1;j++) {
				Matrix_X.set(i, j, m_Instances.instance(i).value(j));
			}
			Matrix_X.set(i, m_NumAtts-1, 1);
		}
		//СԻع
	    boolean success = true;
	    double ridge = 0.1;
	    Matrix solution = new Matrix(m_NumAtts, 1);
	    do {
	      Matrix ss = Matrix_X.transpose().times(Matrix_X);
	      // Խ߼һֵ֤
	      for (int i = 0; i < m_NumAtts; i++)
	        ss.set(i, i, ss.get(i, i) + ridge);
	      Matrix bb = Matrix_X.transpose().times(Matrix_Y);
	      try {
	    	solution = ss.solve(bb);
	        success = true;
	      } 
	      catch (Exception ex) {
	        ridge *= 10;
	        success = false;
	      }
	    } while (!success);
		for(int i=0;i<m_NumAtts;i++) {
			m_Wb[i] = solution.get(i, 0);
		}
	}
	
	//Ԥ⺯
	public double classifyInstance(Instance instance) throws Exception {
		double temp = 0;
		for(int i=0;i<m_NumAtts-1;i++) {
			temp += m_Wb[i] * instance.value(i);
		}
		temp += m_Wb[m_NumAtts-1];
		return temp;
	}
	
	//
	public static void main(String argv[]) {
		runClassifier(new LinearRegression(), argv);
	}
}