package coreset;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;

import base.Line;
import base.Point;
import base.Utility;
import clust.Clustering;
import clust.Objective;
import clust.WeightedPoint;
import clust.ZObjective;

public final class CoresetKMedian extends Coreset {
    public CoresetKMedian(ArrayList<WeightedPoint> instance) {
        super(instance);
    }

    @Override
    protected ArrayList<WeightedPoint> getCoresetPoint(Segment s) {
        ArrayList<WeightedPoint> res = new ArrayList<WeightedPoint>();
        if (s.T.isEmpty())
        	return res;
        WeightedPoint mean = WeightedPoint.mean(s.toPoints());
        mean.weight = Math.round(mean.weight);
        res.add(mean);
        return res;
    }

    @Override
    protected Coreset1D construct1D(List<WeightedPoint> l, Line line , List<Point> optCenters) {
        return new KMedian1D(l, line, optCenters);
    }

    @Override
    protected Objective getObjective() {
    	return ZObjective.getObjective(1.0);
    }

	private ArrayList<Point> projectForCluster(Point center, List<Point> X, double eps) {
		// int dim = X.get(0).dim;
		ArrayList<Point> dir = new ArrayList<Point>();
		for (int i = 0; i < X.size(); i++)
		{
			dir.add(X.get(i).minus(center).normalize());
		}
    	double obj = 0;
    	for (int i = 0; i < X.size(); i++)
    	{
    		obj += Point.dist(X.get(i), center);
    	}
    	ArrayList<Point> res = new ArrayList<Point>();
    	double[] proj = new double[X.size()];
    	Arrays.fill(proj, Double.MAX_VALUE);
    	
    	HashSet<Integer> dirRemain = new HashSet<Integer>();
    	for (int i = 0; i < dir.size(); i++)
    	{
    		dirRemain.add(i);
    	}
    	
    	int nSample = 100;
    	while (dirRemain.size() > 0)
    	{
    		int[] dirRemainArr = new int[dirRemain.size()];
    		int cnt = 0;
    		for (Integer i : dirRemain)
    		{
    			dirRemainArr[cnt++] = i;
    		}

    		Point minDir = null;
    		double minObj = Double.MAX_VALUE;
    		for (int i = 0; i < nSample; i++)
    		//for (Integer i : dirRemain)
    		{
    			//Point sample = Point.gaussian(dim);
    			Point sample = dir.get(
    					dirRemainArr[Utility.rand.nextInt(dirRemain.size())]
    							);
    			// Point sample = dir.get(i);
    			double sum = 0;
    			for (int j = 0; j < X.size(); j++)
    			{
    				double d= Line.dist(X.get(j), new Line(center, sample));
    				double dd = Math.min(proj[j], d);
    				sum += dd;
    			}
    			
    			if (sum <= minObj)
    			{
    				minObj = sum;
    				minDir = sample;
    			}
    		}
    		
    		double sum = 0;
    		res.add(minDir);
    		for (int i = 0; i < X.size(); i++)
    		{
    			double d = Line.dist(X.get(i), new Line(center, minDir));
    			if (d < proj[i])
    			{
    				proj[i] = d;
    			}
    			sum += proj[i];
    		}
    		for (Integer i : dirRemainArr)
    		{
    			// System.out.println(Point.dist(dir.get(i), minDir));
    			if (Point.dist(dir.get(i), minDir) <= eps)
    			{
    				dirRemain.remove(i);
    			}
    		}
    		System.out.println(sum / obj);
    		if (sum / obj <= eps)
    			break;
    	}
    	return res;
	}
	
	@Override
	protected ArrayList<Line> projectToLines(int k, double eps, List<Point> optCenters)
	{
		ArrayList<Line> res = new ArrayList<Line>();

		List<Point> C = optCenters;
		List<WeightedPoint>[] clusters = Clustering.getClusters(this.instance, C);
		
		for (int i = 0; i < C.size(); i++)
		{
			ArrayList<Point> pnt = this.projectForCluster(C.get(i), WeightedPoint.flatten(clusters[i]), eps);
			for (Point p : pnt)
			{
				res.add(new Line(C.get(i), p));
			}
		}
		return res;
	}
}
