package com.elphel.imagej.common;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;

import com.elphel.imagej.tileprocessor.ImageDtt;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import ij.ImagePlus;

public class CholeskyLLTMulti {
	/** Matrix order: the row count of the matrix passed to the constructor. */
	public final int n;
	/**
	 * Factor storage, allocated as n x n by the constructor. The factorization
	 * routines write only the diagonal and the part below it; the rest keeps the
	 * zeros it was allocated with.
	 */
	public final double [][] L;
	// NOTE(review): not read anywhere in this class as visible here - confirm external code still needs it.
	public int solve_step = 1; //0;
	// Tuning factor for the multithreaded factorization: work is split into about thread_scale * (number of threads) chunks.
	public double thread_scale = 10; // 2.5; // split jobs in thread_scale * num_threads chunks

	/**
	 * Wrap the factor array in a JAMA matrix.
	 * NOTE(review): JAMA's (array, rows, columns) constructor is expected to wrap
	 * the array without copying, so the result would share storage with {@link #L} -
	 * confirm against the JAMA version in use.
	 * @return n x n matrix built on the factor array
	 */
	public Matrix getL () {
		final Matrix lower = new Matrix(L, n, n);
		return lower;
	}

	/**
	 * Factorize a matrix immediately on construction.
	 * @param matA matrix to factorize; its row count defines {@link #n}
	 * @param mode 1 - CholeskyLLTMulti_multi(), 2 - CholeskyLLTMulti_fast(),
	 *             0 or any other value - CholeskyLLTMulti_single()
	 */
	public CholeskyLLTMulti(Matrix matA, int mode) {
		n = matA.getRowDimension();
		L = new double[n][n];
		if (mode == 1) {
			CholeskyLLTMulti_multi(matA);
		} else if (mode == 2) {
			CholeskyLLTMulti_fast(matA);
		} else { // mode 0 and every unrecognized mode
			CholeskyLLTMulti_single(matA);
		}
	}
	
	/**
	 * Single-threaded, row-by-row Cholesky factorization (looks like the
	 * Cholesky-Banachiewicz ordering). Fills the diagonal and the part below it
	 * of {@link #L}; only the diagonal and lower triangle of the input are read.
	 * A negative value under the square root is clamped to zero instead of being
	 * reported, so a later division by that diagonal element yields Infinity or NaN.
	 * @param mA matrix to factorize
	 * @return the factor array {@link #L}
	 */
	public double [][] CholeskyLLTMulti_single (Matrix mA) {
		final double[][] src = mA.getArray();
		for (int row = 0; row < n; row++) {
			final double[] cur = L[row];
			double sumSq = 0.0; // sum of squares of the already computed elements of this row
			for (int col = 0; col < row; col++) {
				final double[] prev = L[col];
				double dot = 0.0;
				for (int m = 0; m < col; m++) {
					dot += prev[m] * cur[m];
				}
				final double elem = (src[row][col] - dot) / prev[col];
				cur[col] = elem;
				sumSq = sumSq + elem * elem;
			}
			final double pivot = src[row][row] - sumSq;
			cur[row] = Math.sqrt(Math.max(pivot, 0.0));
			// elements to the right of the diagonal keep their allocation-time zeros
		}
		return L;
	}

	/**
	 * Multithreaded Cholesky factorization A = L*L' (constructor mode 1).
	 * Works column by column: once the diagonal element of a column is known,
	 * every element below it depends only on columns that are already complete,
	 * so the rows of one column can be computed by different threads safely.
	 * Fills the diagonal and the part below it of {@link #L}.
	 * @param mA matrix to factorize; only its diagonal and lower triangle are read
	 * @return the factor array {@link #L}
	 */
	public double [][] CholeskyLLTMulti_multi (Matrix mA) {
		return choleskyByColumns(mA.getArray());
	}

	/**
	 * Constructor mode 2. Uses the same column-by-column multithreaded
	 * factorization as {@link #CholeskyLLTMulti_multi(Matrix)}.
	 * @param mA matrix to factorize; only its diagonal and lower triangle are read
	 * @return the factor array {@link #L}
	 */
	public double [][] CholeskyLLTMulti_fast (Matrix mA) {
		return choleskyByColumns(mA.getArray());
	}

	/**
	 * Column-oriented Cholesky factorization with the rows of each column
	 * distributed between threads. Every sum is accumulated in increasing index
	 * order, the same order the single-threaded routine uses.
	 * A negative value under the square root is clamped to zero; dividing by
	 * such a diagonal element then produces Infinity or NaN in that column.
	 * @param a source matrix data, a[row][column]
	 * @return the factor array {@link #L}
	 */
	private double [][] choleskyByColumns(final double [][] a) {
		final Thread[] threads = ImageDtt.newThreadArray();
		final AtomicInteger ai = new AtomicInteger(0);
		// At least 1: with a zero chunk the shared row counter would never advance
		// and the worker loops would spin forever on small matrices.
		final int threads_chunk = Math.max(1, (int) (n / (thread_scale * threads.length)));

		for (int j = 0; j < n; j++) {
			final int fj = j;
			final double[] Lrowj = L[j];
			// Diagonal element: row j left of the diagonal was completed by the previous columns.
			double d = 0.0;
			for (int k = 0; k < j; k++) {
				d = d + Lrowj[k] * Lrowj[k];
			}
			final double diag = Math.sqrt(Math.max(a[j][j] - d, 0.0));
			Lrowj[j] = diag;

			final int first_row = j + 1;
			final int rows_left = n - first_row;
			// Heuristic: start threads only when there is more than one chunk of rows
			// and the dot products are long enough to be worth the thread start-up.
			if ((j >= threads_chunk) && (rows_left > threads_chunk)) {
				ai.set(first_row);
				for (int ithread = 0; ithread < threads.length; ithread++) {
					threads[ithread] = new Thread() {
						@Override
						public void run() {
							for (int i0 = ai.getAndAdd(threads_chunk); i0 < n; i0 = ai.getAndAdd(threads_chunk)) {
								final int i1 = Math.min(n, i0 + threads_chunk);
								for (int i = i0; i < i1; i++) {
									// Each thread writes only its own rows; it reads row j and its own rows.
									final double[] Lrowi = L[i];
									Lrowi[fj] = (a[i][fj] - lowerDot(Lrowj, Lrowi, fj)) / diag;
								}
							}
						}
					};
				}
				ImageDtt.startAndJoin(threads);
			} else {
				for (int i = first_row; i < n; i++) {
					final double[] Lrowi = L[i];
					Lrowi[j] = (a[i][j] - lowerDot(Lrowj, Lrowi, j)) / diag;
				}
			}
		}
		return L;
	}

	/** Dot product of the first len elements of two rows, accumulated in increasing index order. */
	private static double lowerDot(double [] x, double [] y, int len) {
		double s = 0.0;
		for (int i = 0; i < len; i++) {
			s += x[i] * y[i];
		}
		return s;
	}



	/**
	 * Solve A*X = B with the stored factor.
	 * @param B    right-hand side, see {@link #solve_single(Matrix)}
	 * @param mode requested solver variant; dedicated solvers for modes 1 and 2
	 *             are not implemented, so every value uses the single-threaded solver
	 * @return the solution returned by {@link #solve_single(Matrix)}
	 */
	public Matrix solve (Matrix B, int mode) {
		// Placeholder dispatch: solve_multi (mode 1) and solve_fast (mode 2) do not exist yet.
		return solve_single(B);
	}

	/** Solve A*X = B using the factor stored in this instance.
	   @param  B   A Matrix with as many rows as A and any number of columns.
	   @return     X so that L*L'*X = B
	   @exception  IllegalArgumentException  Matrix row dimensions must agree.
	 */
	public Matrix solve_single (Matrix B) {
		return solve_single(B, L);
	}

	/** Solve L*L'*X = B for a factor supplied as a matrix.
	   @param  B      A Matrix with as many rows as L_mat and any number of columns.
	   @param  L_mat  Lower-triangular Cholesky factor.
	   @return        X so that L*L'*X = B
	   @exception  IllegalArgumentException  Matrix row dimensions must agree.
	 */
	public static Matrix solve_single (Matrix B, Matrix L_mat) {
		return solve_single (B, L_mat.getArray());
	}

	/** Solve L*L'*X = B for a factor supplied as an array.
	   Every column of B is solved independently. Only the diagonal and the part
	   below it of L are read. A zero on the diagonal of L is not detected: the
	   division then produces Infinity or NaN in the result.
	   @param  B   A Matrix with as many rows as L and any number of columns.
	   @param  L   Lower-triangular Cholesky factor, L[row][column].
	   @return     X so that L*L'*X = B, same dimensions as B
	   @exception  IllegalArgumentException  Matrix row dimensions must agree.
	 */
	public static Matrix solve_single (Matrix B, double [][] L) {
		int n = L.length;
		if (B.getRowDimension() != n) {
			throw new IllegalArgumentException("Matrix row dimensions must agree.");
		}
		// Copy right hand side; column-packed, so column c occupies x[c*n] ... x[c*n + n - 1].
		double[] x = B.getColumnPackedCopy ();
		int ncols = B.getColumnDimension();
		for (int col = 0; col < ncols; col++) {
			int offs = col * n;
			// Solve L*Y = B;
			for (int k = 0; k < n; k++) {
				for (int i = 0; i < k ; i++) {
					x[offs + k] -= x[offs + i]*L[k][i];
				}
				x[offs + k] /= L[k][k];
			}
			// Solve L'*X = Y;
			for (int k = n-1; k >= 0; k--) {
				for (int i = k+1; i < n ; i++) {
					x[offs + k] -= x[offs + i]*L[i][k];
				}
				x[offs + k] /= L[k][k];
			}
		}
		return new Matrix(x,n);
	}
	
	/**
	 * Run the Cholesky comparison on a linear system stored in an image.
	 * With n equal to the image height, pixel columns 0..n-1 of each row hold a
	 * row of the n x n matrix and pixel column n holds the right-hand side.
	 * The pixel data is cast to float[], so the image has to be a 32-bit float one.
	 * @param imp_src source image, at least (height + 1) pixels wide
	 * @exception IllegalArgumentException image is too narrow to hold the matrix and the right-hand side
	 */
	public static void testCholesky(ImagePlus imp_src) {
		float [] fpixels = (float[]) imp_src.getProcessor().getPixels();
		int      width = imp_src.getWidth();
		int      height = imp_src.getHeight();
		int n = height;
		// A narrower image would make the reads below wrap into other pixel rows
		// (silently wrong data) or run past the end of the pixel array.
		if (width <= n) {
			throw new IllegalArgumentException(
					"testCholesky(): image "+width+"x"+height+" is too narrow, need at least "+(n+1)+" columns");
		}
		double [][] a = new double [n][n];
		double [][] b = new double [n][1];
		for (int i = 0; i < n; i++) {
			for (int j= 0; j < n; j++) {
				a[i][j] = fpixels[i*width+j];
			}
			b[i][0] = fpixels[i*width+n];
		}
		testCholesky(
				new Matrix(a), // Matrix wjtjlambda,
				new Matrix(b), // Matrix jty)
				imp_src.getTitle()); // String title);
	}

	/**
	 * Benchmark and compare three factorizations of the same matrix: JAMA
	 * CholeskyDecomposition, CholeskyLLTMulti in mode 0 and CholeskyBlock.
	 * Prints elapsed times, solves the system with each factor and displays the
	 * factors and their differences from the JAMA one.
	 * @param wjtjlambda_in matrix to factorize
	 * @param jty_in        right-hand side
	 * @param title         prefix for the title of the displayed image stack
	 * @return solutions obtained with, in order: JAMA solve(), solve_single() with
	 *         the mode 0 factor, solve_single() with the CholeskyBlock factor
	 */
	public static Matrix[] testCholesky(
			Matrix wjtjlambda_in,
			Matrix jty_in,
			String title) {
		int block_size = 100; // 70;//70~best //  64; // 60; // 70; // 80; // 120; // 150; // 200; // 100; // 4; // 10; // 100; // 28;
		// Truncation is compiled in but disabled; when enabled it keeps the top-left tr_size x tr_size part.
		boolean truncate = false;
		int trunc_size = 199; // 0 to use full size
		Matrix wjtjlambda,jty;
		if (truncate) {
			int tr_size = (trunc_size==0)?(((wjtjlambda_in.getRowDimension())/block_size) * block_size):trunc_size;
			wjtjlambda=wjtjlambda_in.getMatrix (
					0, // int i0,
					tr_size-1, // int i1,
					0, // int j0,
					tr_size-1); //int j1)
			jty=jty_in.getMatrix (
					0, // int i0,
					tr_size-1, // int i1,
					0, // int j0,
					0); //int j1)
		} else {
			wjtjlambda=wjtjlambda_in;
			jty = jty_in;
		}
		String dbg_title=title+"ch_diff_choleskyBlock-choleskyDecomposition-"+block_size+(truncate?("-truncate"+trunc_size):"");
		// Each factorization gets its own copy of the input.
		Matrix wjtjlambda_copy0 = new Matrix(wjtjlambda.getArrayCopy());
		Matrix wjtjlambda_copy1 = new Matrix(wjtjlambda.getArrayCopy());
		// NOTE(review): wjtjlambda_copy2 is never used in this method.
		Matrix wjtjlambda_copy2 = new Matrix(wjtjlambda.getArrayCopy());
		Matrix wjtjlambda_copy = new Matrix(wjtjlambda.getArrayCopy());
		// starts[i] - seconds elapsed since start_time at checkpoint i; starts[2] is never assigned.
		double [] starts = new double[8];
		double start_time = (((double) System.nanoTime()) * 1E-9);

		CholeskyDecomposition choleskyDecomposition = new CholeskyDecomposition(wjtjlambda_copy);
		starts[0] = (((double) System.nanoTime()) * 1E-9) - start_time;
		CholeskyLLTMulti choleskyLLTMulti_single = new CholeskyLLTMulti(wjtjlambda_copy0,0);
		starts[1] = (((double) System.nanoTime()) * 1E-9) - start_time;
		CholeskyBlock choleskyBlock  = new CholeskyBlock(wjtjlambda_copy1.getArray(),block_size);		
//		starts[2] = (((double) System.nanoTime()) * 1E-9) - start_time;
//		choleskyBlock.choleskyBlockMulti();
		starts[3] = (((double) System.nanoTime()) * 1E-9) - start_time;
		double [][] LTriangle = choleskyBlock.getL().getArray(); // get_LTriangle();
		Matrix mdelta_cholesky = choleskyDecomposition.solve(jty);
		starts[4] = (((double) System.nanoTime()) * 1E-9) - start_time;
		Matrix mdelta_cholesky_single = CholeskyLLTMulti.solve_single(jty, choleskyLLTMulti_single.getL());
		starts[5] = (((double) System.nanoTime()) * 1E-9) - start_time;
		Matrix mdelta_cholesky_multi =  CholeskyLLTMulti.solve_single(jty, LTriangle);
		starts[6] = (((double) System.nanoTime()) * 1E-9) - start_time;
		// NOTE(review): mdelta_cholesky_block is only timed; it is not part of the returned array.
		Matrix mdelta_cholesky_block =  choleskyBlock.solve(jty);
		starts[7] = (((double) System.nanoTime()) * 1E-9) - start_time;

		
		System.out.println("testCholesky(): block_size=             "+block_size);
		System.out.println("testCholesky(): choleskyDecomposition:  "+(starts[0])+" sec");
		System.out.println("testCholesky(): choleskyLLTMulti_single:"+(starts[1]-starts[0])+" sec");
		
		System.out.println("testCholesky(): CholeskyBlock():        "+(starts[3]-starts[1])+" sec");
//		System.out.println("testCholesky(): choleskyBlockMulti():   "+(starts[3]-starts[2])+" sec");
		// NOTE(review): the interval reported as "get_LTriangle()" covers both
		// choleskyBlock.getL().getArray() and choleskyDecomposition.solve(jty).
		System.out.println("testCholesky(): get_LTriangle():        "+(starts[4]-starts[3])+" sec");
		System.out.println("testCholesky(): solve_single(,single):  "+(starts[5]-starts[4])+" sec");
		System.out.println("testCholesky(): solve_single(,block):   "+(starts[6]-starts[5])+" sec");
		System.out.println("testCholesky(): block.solve():          "+(starts[7]-starts[6])+" sec");
		System.out.println("testCholesky(): title=                  "+title);
		System.out.println("testCholesky(): dbg_title=              "+dbg_title);

		// Differences of both factors from the JAMA reference factor.
		Matrix ch_diff =      choleskyLLTMulti_single.getL().minus(choleskyDecomposition.getL());
		Matrix ch_diff_fast = choleskyBlock.getL().minus(choleskyDecomposition.getL());
		double [][] dbg_img = {
				choleskyLLTMulti_single.getL().getRowPackedCopy(),
				choleskyBlock.getL().getRowPackedCopy(),
				choleskyDecomposition.getL().getRowPackedCopy(),
				ch_diff.getRowPackedCopy(),
				ch_diff_fast.getRowPackedCopy()};
		String[] dbg_titles = {"choleskyLLTMulti_single","choleskyBlock","choleskyDecomposition",
				"choleskyLLTMulti_single-choleskyDecomposition","choleskyBlock-choleskyDecomposition"};
		ShowDoubleFloatArrays.showArrays(
				dbg_img, // double[] pixels,
				ch_diff.getRowDimension(), // int width,
				ch_diff.getRowDimension(), // int height,
				true,
				dbg_title, // String title)
				dbg_titles);
		return new Matrix[] {mdelta_cholesky, mdelta_cholesky_single, mdelta_cholesky_multi};
	}

	/**
	 * Benchmark and compare the JAMA CholeskyDecomposition with the three
	 * CholeskyLLTMulti constructor modes (0, 1 and 2). Prints elapsed times,
	 * solves the system and displays the factors and their differences from the
	 * JAMA one. The printed lines use the prefix "testCholesky():".
	 * @param wjtjlambda matrix to factorize
	 * @param jty        right-hand side
	 * @return four solutions: JAMA solve(), then solve() with modes 0, 1 and 2
	 *         of the mode 1 CholeskyLLTMulti instance
	 */
	public static Matrix[] testCholesky0(
			Matrix wjtjlambda,
			Matrix jty) {
		// NOTE(review): block_size and truncate are not used in this method.
		int block_size = 128;
		boolean truncate = true;
	
		// Each factorization gets its own copy of the input.
		Matrix wjtjlambda_copy0 = new Matrix(wjtjlambda.getArrayCopy());
		Matrix wjtjlambda_copy1 = new Matrix(wjtjlambda.getArrayCopy());
		Matrix wjtjlambda_copy2 = new Matrix(wjtjlambda.getArrayCopy());
		Matrix wjtjlambda_copy = new Matrix(wjtjlambda.getArrayCopy());
		// starts[i] - seconds elapsed since start_time at checkpoint i.
		double [] starts = new double[8];
		double start_time = (((double) System.nanoTime()) * 1E-9);

		CholeskyDecomposition choleskyDecomposition = new CholeskyDecomposition(wjtjlambda_copy);
		starts[0] = (((double) System.nanoTime()) * 1E-9) - start_time;
		CholeskyLLTMulti choleskyLLTMulti_single = new CholeskyLLTMulti(wjtjlambda_copy0,0);
		starts[1] = (((double) System.nanoTime()) * 1E-9) - start_time;
		CholeskyLLTMulti choleskyLLTMulti_multi = new CholeskyLLTMulti(wjtjlambda_copy1,1);
		starts[2] = (((double) System.nanoTime()) * 1E-9) - start_time;
		CholeskyLLTMulti choleskyLLTMulti_fast = new CholeskyLLTMulti(wjtjlambda_copy2,2);
		starts[3] = (((double) System.nanoTime()) * 1E-9) - start_time;
		// All CholeskyLLTMulti solves below use the mode 1 factor, including the one named "..._single".
		CholeskyLLTMulti choleskyLLTMulti = choleskyLLTMulti_multi; // fast;
		Matrix mdelta_cholesky = choleskyDecomposition.solve(jty);
		starts[4] = (((double) System.nanoTime()) * 1E-9) - start_time;
		Matrix mdelta_cholesky_single = choleskyLLTMulti.solve(jty, 0);
		starts[5] = (((double) System.nanoTime()) * 1E-9) - start_time;
		Matrix mdelta_cholesky_multi =  choleskyLLTMulti.solve(jty, 1);
		starts[6] = (((double) System.nanoTime()) * 1E-9) - start_time;
		Matrix mdelta_cholesky_fast =  choleskyLLTMulti.solve(jty, 2);
		starts[7] = (((double) System.nanoTime()) * 1E-9) - start_time;
		
		
		System.out.println("testCholesky(): choleskyDecomposition:  "+(starts[0])+" sec");
		System.out.println("testCholesky(): choleskyLLTMulti_single:"+(starts[1]-starts[0])+" sec");
		
		System.out.println("testCholesky(): choleskyLLTMulti_multi: "+(starts[2]-starts[1])+" sec");
		System.out.println("testCholesky(): choleskyLLTMulti_fast:  "+(starts[3]-starts[2])+" sec");

		System.out.println("testCholesky(): mdelta_cholesky:        "+(starts[4]-starts[3])+" sec");
		System.out.println("testCholesky(): mdelta_cholesky_single: "+(starts[5]-starts[4])+" sec");
		System.out.println("testCholesky(): mdelta_cholesky_multi:  "+(starts[6]-starts[5])+" sec");
		System.out.println("testCholesky(): mdelta_cholesky_fast:   "+(starts[7]-starts[6])+" sec");
		// Differences of the mode 1 and mode 2 factors from the JAMA reference factor.
		Matrix ch_diff =      choleskyLLTMulti.getL().minus(choleskyDecomposition.getL());
		Matrix ch_diff_fast = choleskyLLTMulti_fast.getL().minus(choleskyDecomposition.getL());
		double [][] dbg_img = {
				choleskyLLTMulti.getL().getRowPackedCopy(),
				choleskyLLTMulti_fast.getL().getRowPackedCopy(),
				choleskyDecomposition.getL().getRowPackedCopy(),
				ch_diff.getRowPackedCopy(),
				ch_diff_fast.getRowPackedCopy()};
		String[] dbg_titles = {"choleskyLLTMulti","choleskyLLTMulti_fast","choleskyDecomposition",
				"choleskyLLTMulti-choleskyDecomposition","choleskyLLTMulti_fast-choleskyDecomposition"};
		ShowDoubleFloatArrays.showArrays(
				dbg_img, // double[] pixels,
				ch_diff.getRowDimension(), // int width,
				ch_diff.getRowDimension(), // int height,
				true,
				"ch_diff_choleskyLLTMulti-choleskyDecomposition", // String title)
				dbg_titles);
		
		return new Matrix[] {mdelta_cholesky, mdelta_cholesky_single, mdelta_cholesky_multi, mdelta_cholesky_fast};
	}
	
}
