package com.elphel.imagej.common;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;

import com.elphel.imagej.tileprocessor.ImageDtt;

import Jama.CholeskyDecomposition;
import Jama.Matrix;
import ij.ImagePlus;

public class CholeskyLDLTMulti {
	/** Matrix order: the row dimension of the matrix passed to the constructor. */
	public final int n;
	/**
	 * Decomposition result: the array returned by the selected CholeskyLDLTMulti_*() method.
	 * As read by solve0(): elements strictly below the diagonal are the unit lower-triangular
	 * factor L (its unit diagonal is implied, not stored), and the diagonal elements are D.
	 */
	public final double [][] L;
	/** Number of rows processed per block in solve() (single-threaded triangle followed by a parallel rectangle). */
	public int solve_step = 10;
	public double thread_scale = 2.5; // split jobs in thread_scale * num_threads chunks
	
	

	/**
	 * Decomposes {@code matA} with the multithreaded algorithm
	 * ({@link #CholeskyLDLTMulti_multi(Matrix)}); same as calling the two-argument
	 * constructor with {@code mode = 1}.
	 *
	 * @param matA matrix to decompose; its backing array is used as working storage by the
	 *             decomposition, pass a copy if the original values are still needed
	 */
	public CholeskyLDLTMulti(Matrix matA) {
		this(matA, 1); // mode 1 selects CholeskyLDLTMulti_multi()
	}
	
	public CholeskyLDLTMulti(Matrix matA, int mode) {
		n = matA.getRowDimension();
		switch (mode) {
		case 0: L = CholeskyLDLTMulti_single(matA); break;
		case 1: L = CholeskyLDLTMulti_multi(matA); break;
		case 2: L = CholeskyLDLTMulti_fast(matA); break;
		default:L = CholeskyLDLTMulti_single(matA);
		}
	}	

	/**
	 * Multithreaded in-place LDL' decomposition.
	 * <p>
	 * Works directly on the array obtained from {@code matA.getArray()}, so {@code matA} is
	 * overwritten: on return the diagonal holds D and the elements strictly below the diagonal
	 * hold the unit lower-triangular factor L. Only the diagonal and the lower triangle are read
	 * or written; the matrix order is taken from the field {@link #n}.
	 * <p>
	 * For each column {@code j} the diagonal element is computed first (as a parallel reduction
	 * once {@code j} exceeds one work chunk), then the column below it is filled, rows being
	 * handed out to worker threads in chunks through a shared atomic counter. Because chunks are
	 * assigned dynamically, the grouping of the partial sums of the diagonal reduction may differ
	 * between runs (last-bit rounding differences are possible).
	 * <p>
	 * There is no pivoting and no test for a zero diagonal element: a zero pivot silently
	 * produces infinite/NaN values.
	 *
	 * @param matA symmetric matrix to decompose (modified in place)
	 * @return the backing array of {@code matA}, now containing L (below the diagonal) and D
	 *         (on the diagonal)
	 */
	public double[][] CholeskyLDLTMulti_multi(Matrix matA) {
		final double [][] a =          matA.getArray();
		final Thread[] threads =       ImageDtt.newThreadArray();
		final AtomicInteger ai =       new AtomicInteger(0);
		final AtomicInteger ati =      new AtomicInteger(0);
		final double [] ajj_threaded = new double[threads.length];
		// Must be at least 1: with a zero chunk (n < thread_scale * threads.length) the shared
		// counter would never advance and every worker would loop forever.
		final int threads_chunk = Math.max(1, (int) (n / (thread_scale * threads.length)));
		for (int j = 0; j < n; j++) {
			final int fj = j;
			final double [] aj = a[j];
			double ajj = aj[j];
			// Diagonal: d[j] = a[j][j] - sum(d[i] * l[j][i]^2), i < j
			if (j > threads_chunk) { // enough terms to split between threads
				Arrays.fill(ajj_threaded, 0.0);
				ai.set(0);
				ati.set(0);
				for (int ithread = 0; ithread < threads.length; ithread++) {
					threads[ithread] = new Thread() {
						@Override
						public void run() {
							final int nthread = ati.getAndIncrement(); // private slot for the partial sum
							double sum = 0.0; // accumulate locally, store once
							for (int i0 = ai.getAndAdd(threads_chunk); i0 < fj; i0 = ai.getAndAdd(threads_chunk)) {
								final int i1 = Math.min(fj, i0 + threads_chunk);
								for (int i = i0; i < i1; i++) {
									final double aji = aj[i];
									sum += a[i][i] * aji * aji;
								}
							}
							ajj_threaded[nthread] = sum;
						}
					};
				}
				ImageDtt.startAndJoin(threads);
				for (int nthread = 0; nthread < ajj_threaded.length; nthread++) {
					ajj -= ajj_threaded[nthread];
				}
			} else { // too few terms, run in the calling thread
				for (int i = 0; i < j; i++) {
					final double aji = aj[i];
					ajj -= a[i][i] * aji * aji;
				}
			}
			aj[j] = ajj;
			final double fajj = ajj; // effectively final copy for the worker threads
			// Column below the diagonal: l[k][j] = (a[k][j] - sum(d[i] * l[k][i] * l[j][i])) / d[j], k > j
			if ((n - (j + 1)) > threads_chunk) { // more remaining rows than one chunk
				ai.set(j + 1);
				for (int ithread = 0; ithread < threads.length; ithread++) {
					threads[ithread] = new Thread() {
						@Override
						public void run() {
							for (int k0 = ai.getAndAdd(threads_chunk); k0 < n; k0 = ai.getAndAdd(threads_chunk)) {
								final int k1 = Math.min(n, k0 + threads_chunk);
								for (int k = k0; k < k1; k++) {
									final double [] ak = a[k];
									double akj = ak[fj];
									for (int i = 0; i < fj; i++) {
										akj -= a[i][i] * ak[i] * aj[i];
									}
									ak[fj] = akj / fajj;
								}
							}
						}
					};
				}
				ImageDtt.startAndJoin(threads);
			} else { // short tail, run in the calling thread
				for (int k = j + 1; k < n; k++) {
					final double [] ak = a[k];
					double akj = ak[j];
					for (int i = 0; i < j; i++) {
						akj -= a[i][i] * ak[i] * aj[i];
					}
					ak[j] = akj / ajj;
				}
			}
		}
		return a;
	}

	
	
	/**
	 * Experimental multithreaded in-place LDL' decomposition that lets every worker read the
	 * already finished columns from its own copy ("slot") of the matrix instead of the shared
	 * array.
	 * <p>
	 * Works directly on the array obtained from {@code matA.getArray()}, so {@code matA} is
	 * overwritten: on return the diagonal holds D and the elements strictly below the diagonal
	 * hold the unit lower-triangular factor L. The matrix order is taken from the field
	 * {@link #n}. No pivoting and no zero-pivot test is performed.
	 * <p>
	 * Slots are brought up to date lazily: a worker that receives work copies the finished
	 * columns it is missing (lower triangle only) from the shared array before using its slot.
	 * Before each parallel phase the slots are ordered most-recently-synchronized first, so the
	 * workers that actually get chunks tend to need the least copying.
	 * <p>
	 * NOTE: allocates {@code threads.length * n * n} doubles for the slots, which is substantial
	 * for large matrices.
	 *
	 * @param matA symmetric matrix to decompose (modified in place)
	 * @return the backing array of {@code matA}, now containing L (below the diagonal) and D
	 *         (on the diagonal)
	 */
	public double[][] CholeskyLDLTMulti_fast(Matrix matA) {
		final double [][] a =          matA.getArray();
		final Thread[] threads =       ImageDtt.newThreadArray();
		final AtomicInteger ai =       new AtomicInteger(0);
		final AtomicInteger ati =      new AtomicInteger(0);
		final double [] ajj_threaded = new double[threads.length];
		// Must be at least 1: with a zero chunk (n < thread_scale * threads.length) the shared
		// counter would never advance and every worker would loop forever.
		final int threads_chunk = Math.max(1, (int) (n / (thread_scale * threads.length)));

		final Integer [] slots_order = new Integer [threads.length];    // slot indices, most up-to-date first
		final int [] slot_j  = new int [threads.length];                // last column (inclusive) mirrored in each slot
		final double [][][] slots_a = new double [threads.length][n][n]; // per-slot mirrors of the finished columns
		Arrays.fill(slot_j, -1); // nothing mirrored yet
		for (int j = 0; j < n; j++) {
			final int fj = j;
			final double [] aj = a[j];
			double ajj = aj[j];
			// Diagonal: d[j] = a[j][j] - sum(d[i] * l[j][i]^2), i < j
			if (j > threads_chunk) { // enough terms to split between threads
				orderSlotsNewestFirst(slots_order, slot_j);
				Arrays.fill(ajj_threaded, 0.0);
				ai.set(0);
				ati.set(0);
				for (int ithread = 0; ithread < threads.length; ithread++) {
					threads[ithread] = new Thread() {
						@Override
						public void run() {
							final int slot = slots_order[ati.getAndIncrement()];
							final double [][] local_a = slots_a[slot];
							double sum = 0.0; // accumulate locally, store once
							for (int i0 = ai.getAndAdd(threads_chunk); i0 < fj; i0 = ai.getAndAdd(threads_chunk)) {
								if (slot_j[slot] < fj - 1) { // catch up only when this worker really has work
									syncSlotColumns(a, local_a, slot_j[slot] + 1, fj, n);
									slot_j[slot] = fj - 1;
								}
								final int i1 = Math.min(fj, i0 + threads_chunk);
								final double [] local_aj = local_a[fj];
								for (int i = i0; i < i1; i++) {
									final double aji = local_aj[i];
									sum += local_a[i][i] * aji * aji;
								}
							}
							ajj_threaded[slot] = sum;
						}
					};
				}
				ImageDtt.startAndJoin(threads);
				for (int nthread = 0; nthread < ajj_threaded.length; nthread++) {
					ajj -= ajj_threaded[nthread];
				}
			} else { // too few terms, run in the calling thread
				for (int i = 0; i < j; i++) {
					final double aji = aj[i];
					ajj -= a[i][i] * aji * aji;
				}
			}
			aj[j] = ajj;
			final double fajj = ajj; // effectively final copy for the worker threads
			// Column below the diagonal: l[k][j] = (a[k][j] - sum(d[i] * l[k][i] * l[j][i])) / d[j], k > j
			if ((n - (j + 1)) > threads_chunk) { // more remaining rows than one chunk
				// slot freshness may have changed during the diagonal phase
				orderSlotsNewestFirst(slots_order, slot_j);
				ai.set(j + 1);
				ati.set(0);
				for (int ithread = 0; ithread < threads.length; ithread++) {
					threads[ithread] = new Thread() {
						@Override
						public void run() {
							final int slot = slots_order[ati.getAndIncrement()];
							final double [][] local_a = slots_a[slot];
							for (int k0 = ai.getAndAdd(threads_chunk); k0 < n; k0 = ai.getAndAdd(threads_chunk)) {
								if (slot_j[slot] < fj - 1) { // catch up only when this worker really has work
									syncSlotColumns(a, local_a, slot_j[slot] + 1, fj, n);
									slot_j[slot] = fj - 1;
								}
								final double [] local_aj = local_a[fj];
								final int k1 = Math.min(n, k0 + threads_chunk);
								for (int k = k0; k < k1; k++) {
									final double [] local_ak = local_a[k];
									// Column fj is not mirrored in the slots (only columns < fj are),
									// so its source value has to come from the shared array.
									double akj = a[k][fj];
									for (int i = 0; i < fj; i++) {
										akj -= local_a[i][i] * local_ak[i] * local_aj[i];
									}
									a[k][fj] = akj / fajj; // publish to the shared array
								}
							}
						}
					};
				}
				ImageDtt.startAndJoin(threads);
			} else { // short tail, run in the calling thread on the shared array
				for (int k = j + 1; k < n; k++) {
					final double [] ak = a[k];
					double akj = ak[j];
					for (int i = 0; i < j; i++) {
						akj -= a[i][i] * ak[i] * aj[i];
					}
					ak[j] = akj / ajj;
				}
			}
		}
		return a;
	}

	/** Fills {@code slots_order} with all slot indices, sorted so that the slot mirroring the most columns comes first. */
	private static void orderSlotsNewestFirst(final Integer [] slots_order, final int [] slot_j) {
		Arrays.setAll(slots_order, i -> i);
		Arrays.sort(slots_order, (la, lb) -> Integer.compare(slot_j[lb], slot_j[la])); // decreasing slot_j
	}

	/** Copies the lower-triangle part (rows {@code col..rows-1}) of columns {@code col0..col1-1} from {@code a} to {@code local_a}. */
	private static void syncSlotColumns(double [][] a, double [][] local_a, int col0, int col1, int rows) {
		for (int col = col0; col < col1; col++) {
			for (int row = col; row < rows; row++) {
				local_a[row][col] = a[row][col];
			}
		}
	}
	
	/**
	 * Single-threaded in-place LDL' decomposition (reference implementation).
	 * <p>
	 * Operates on the array obtained from {@code matA.getArray()}, so {@code matA} is
	 * overwritten: on return the diagonal holds D and the elements strictly below the diagonal
	 * hold the unit lower-triangular factor L. The matrix order is taken from the field
	 * {@link #n}. No pivoting and no zero-pivot test is performed.
	 *
	 * @param matA symmetric matrix to decompose (modified in place, copy externally if needed)
	 * @return the backing array of {@code matA} holding the decomposition
	 */
	public double [][] CholeskyLDLTMulti_single(Matrix matA) { // single-threaded
		final double [][] a = matA.getArray();
		for (int col = 0; col < n; col++) {
			final double [] pivot_row = a[col];
			// diagonal: subtract contributions of all previous columns
			for (int prev = 0; prev < col; prev++) {
				pivot_row[col] -= a[prev][prev] * pivot_row[prev] * pivot_row[prev];
			}
			// column below the diagonal
			for (int row = col + 1; row < n; row++) {
				final double [] cur_row = a[row];
				for (int prev = 0; prev < col; prev++) {
					cur_row[col] -= a[prev][prev] * cur_row[prev] * pivot_row[prev];
				}
				cur_row[col] /= pivot_row[col];
			}
		}
		return a;
	}


	
	
	/**
	 * Solves {@code A*X = B} using the stored LDL' decomposition, entirely in the calling
	 * thread (no worker threads are used here).
	 * <p>
	 * The right-hand side is read with {@code getColumnPackedCopy()} and only its first
	 * {@code n} values are processed, i.e. the code is written for a single-column {@code B}.
	 *
	 * @param B right-hand side, must have {@code n} rows
	 * @return solution as a matrix built from the work vector with {@code n} rows
	 * @throws IllegalArgumentException if the row dimension of {@code B} differs from {@code n}
	 */
	public Matrix solve0 (Matrix B) {
		if (B.getRowDimension() != n) {
			throw new IllegalArgumentException("Matrix row dimensions must agree.");
		}
		// Work on a copy of the right-hand side
		final double[] sol = B.getColumnPackedCopy();
		// Forward substitution with the unit lower-triangular L: L*Z = B
		for (int row = 0; row < n; row++) {
			final double [] l_row = L[row];
			for (int col = 0; col < row; col++) {
				sol[row] -= sol[col] * l_row[col];
			}
		}
		// Diagonal scaling: D*Y = Z
		for (int row = 0; row < n; row++) {
			sol[row] /= L[row][row];
		}
		// Back substitution with L': L'*X = Y
		for (int row = n - 1; row >= 0; row--) {
			for (int below = row + 1; below < n; below++) {
				sol[row] -= sol[below] * L[below][row];
			}
		}
		return new Matrix(sol, n);
	}
		
	/**
	 * Solves {@code A*X = B} using the stored LDL' decomposition, with the bulk of the
	 * substitution work distributed between threads.
	 * <p>
	 * Both substitutions proceed in blocks of {@link #solve_step} rows: the small triangular
	 * part of each block is solved in the calling thread, then the contribution of the just
	 * solved unknowns is subtracted from all remaining rows in parallel (rows are handed out in
	 * chunks through a shared atomic counter; every row is updated by exactly one thread).
	 * <p>
	 * The right-hand side is read with {@code getColumnPackedCopy()} and only its first
	 * {@code n} values are processed, i.e. the code is written for a single-column {@code B}.
	 *
	 * @param B right-hand side, must have {@code n} rows
	 * @return solution as a matrix built from the work vector with {@code n} rows
	 * @throws IllegalArgumentException if the row dimension of {@code B} differs from {@code n}
	 */
	public Matrix solve (Matrix B) {
		if (B.getRowDimension() != n) {
			throw new IllegalArgumentException("Matrix row dimensions must agree.");
		}
		final Thread[] threads = ImageDtt.newThreadArray();
		final AtomicInteger ai = new AtomicInteger(0);
		// Must be at least 1: with a zero chunk (n < thread_scale * threads.length) the shared
		// counter would never advance and every worker would loop forever.
		final int threads_chunk = Math.max(1, (int) (n / (thread_scale * threads.length)));
		// solve_step is a public mutable field; a non-positive value would stall the block loops
		final int step = Math.max(1, solve_step);
		final double[] x = B.getColumnPackedCopy(); // work copy of the right-hand side (single column)

		// Forward substitution with the unit lower-triangular L: L*Z = B
		for (int row0 = 0; row0 < n; row0 += step) {
			final int frow0 = row0;
			final int frow1 = Math.min(row0 + step, n);
			// triangle of the current block, single-threaded
			for (int row = frow0; row < frow1; row++) {
				final double [] l_row = L[row];
				double xr = x[row];
				for (int i = frow0; i < row; i++) {
					xr -= x[i] * l_row[i];
				}
				x[row] = xr;
			}
			// rectangle below the block: propagate the solved unknowns to all following rows
			if (frow1 < n) {
				ai.set(frow1);
				for (int ithread = 0; ithread < threads.length; ithread++) {
					threads[ithread] = new Thread() {
						@Override
						public void run() {
							for (int chunk0 = ai.getAndAdd(threads_chunk); chunk0 < n; chunk0 = ai.getAndAdd(threads_chunk)) {
								final int chunk1 = Math.min(n, chunk0 + threads_chunk);
								for (int row = chunk0; row < chunk1; row++) {
									final double [] l_row = L[row];
									double xr = x[row];
									for (int col = frow0; col < frow1; col++) {
										xr -= x[col] * l_row[col];
									}
									x[row] = xr;
								}
							}
						}
					};
				}
				ImageDtt.startAndJoin(threads);
			}
		}

		// Diagonal scaling: D*Y = Z
		ai.set(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					for (int chunk0 = ai.getAndAdd(threads_chunk); chunk0 < n; chunk0 = ai.getAndAdd(threads_chunk)) {
						final int chunk1 = Math.min(n, chunk0 + threads_chunk);
						for (int row = chunk0; row < chunk1; row++) {
							x[row] /= L[row][row];
						}
					}
				}
			};
		}
		ImageDtt.startAndJoin(threads);

		// Back substitution with L': L'*X = Y. Blocks are walked from the bottom; x[frow1] is
		// already final when a block starts, the block finalizes x[frow0..frow1-1].
		for (int row1 = n - 1; row1 > 0; row1 -= step) {
			final int frow1 = row1;
			final int frow0 = Math.max(row1 - step, 0);
			// triangle of the current block, single-threaded
			for (int row = frow1 - 1; row >= frow0; row--) {
				double xr = x[row];
				for (int i = row + 1; i <= frow1; i++) {
					xr -= x[i] * L[i][row];
				}
				x[row] = xr;
			}
			// rectangle above the block: rows 0..frow0-1 receive the contribution of
			// x[frow0+1..frow1]; x[frow0] itself is applied by the next block, where it is frow1
			if (frow0 > 0) { // nothing above the block on the last step
				ai.set(0);
				for (int ithread = 0; ithread < threads.length; ithread++) {
					threads[ithread] = new Thread() {
						@Override
						public void run() {
							for (int chunk0 = ai.getAndAdd(threads_chunk); chunk0 < frow0; chunk0 = ai.getAndAdd(threads_chunk)) {
								final int chunk1 = Math.min(frow0, chunk0 + threads_chunk);
								for (int row = chunk0; row < chunk1; row++) {
									double xr = x[row];
									for (int col = frow0 + 1; col <= frow1; col++) {
										xr -= x[col] * L[col][row];
									}
									x[row] = xr;
								}
							}
						}
					};
				}
				ImageDtt.startAndJoin(threads);
			}
		}
		return new Matrix(x, n);
	}
		
	/**
	 * Runs the timing comparison on a linear system stored in an image.
	 * <p>
	 * With {@code n} equal to the image height, pixel {@code (column j, row i)} for
	 * {@code j < n} is taken as matrix element {@code a[i][j]} and the pixel in column
	 * {@code n} of each row as the right-hand side. The pixel array is cast to
	 * {@code float[]}, and the indexing requires the image to be wider than it is tall.
	 * The solutions returned by the two-argument overload are discarded.
	 *
	 * @param imp_src image holding the system matrix and the right-hand side column
	 */
	public static void testCholesky(ImagePlus imp_src) {
		final float [] pix = (float[]) imp_src.getProcessor().getPixels();
		final int stride = imp_src.getWidth();
		final int size =   imp_src.getHeight();
		final double [][] lhs = new double [size][size];
		final double [][] rhs = new double [size][1];
		for (int row = 0, offs = 0; row < size; row++, offs += stride) { // offs == row * stride
			final double [] lhs_row = lhs[row];
			for (int col = 0; col < size; col++) {
				lhs_row[col] = pix[offs + col];
			}
			rhs[row][0] = pix[offs + size]; // column right after the square matrix part
		}
		final Matrix mat_lhs = new Matrix(lhs); // Matrix wjtjlambda
		final Matrix mat_rhs = new Matrix(rhs); // Matrix jty
		testCholesky(mat_lhs, mat_rhs);
	}
	/**
	 * Times the Jama Cholesky decomposition against the three LDL' decomposition variants of
	 * this class and then times three solvers on the same right-hand side, printing every
	 * duration (in seconds) to {@code System.out}.
	 * <p>
	 * Each decomposition receives its own copy of {@code wjtjlambda}, so the argument itself is
	 * not modified. Both LDL' solvers are run on the decomposition produced by mode 2
	 * ({@code CholeskyLDLTMulti_fast}).
	 *
	 * @param wjtjlambda system matrix
	 * @param jty        right-hand side
	 * @return solutions in the order: Jama {@code CholeskyDecomposition.solve()},
	 *         {@link #solve0(Matrix)}, {@link #solve(Matrix)}
	 */
	public static Matrix[] testCholesky(
			Matrix wjtjlambda,
			Matrix jty) {
		// private working copies, one per decomposition
		final Matrix src_single = new Matrix(wjtjlambda.getArrayCopy());
		final Matrix src_multi =  new Matrix(wjtjlambda.getArrayCopy());
		final Matrix src_fast =   new Matrix(wjtjlambda.getArrayCopy());
		final Matrix src_jama =   new Matrix(wjtjlambda.getArrayCopy());
		final String [] stage_labels = {
				"testCholesky(): choleskyDecomposition:    ",
				"testCholesky(): choleskyLDLTMulti_single: ",
				"testCholesky(): choleskyLDLTMulti_multi:  ",
				"testCholesky(): choleskyLDLTMulti_fast:   ",
				"testCholesky(): mdelta_cholesky:          ",
				"testCholesky(): mdelta_cholesky_multi0:   ",
				"testCholesky(): mdelta_cholesky_multi:    "};
		final double [] finished = new double [stage_labels.length]; // seconds since t_begin when each stage ended
		final double t_begin = System.nanoTime() * 1E-9;

		// decompositions
		final CholeskyDecomposition jama_cholesky = new CholeskyDecomposition(src_jama);
		finished[0] = System.nanoTime() * 1E-9 - t_begin;
		new CholeskyLDLTMulti(src_single, 0); // result not used, timed only
		finished[1] = System.nanoTime() * 1E-9 - t_begin;
		new CholeskyLDLTMulti(src_multi, 1);  // result not used, timed only
		finished[2] = System.nanoTime() * 1E-9 - t_begin;
		final CholeskyLDLTMulti ldlt = new CholeskyLDLTMulti(src_fast, 2);
		finished[3] = System.nanoTime() * 1E-9 - t_begin;

		// solvers
		final Matrix x_jama = jama_cholesky.solve(jty);
		finished[4] = System.nanoTime() * 1E-9 - t_begin;
		final Matrix x_solve0 = ldlt.solve0(jty);
		finished[5] = System.nanoTime() * 1E-9 - t_begin;
		final Matrix x_solve = ldlt.solve(jty);
		finished[6] = System.nanoTime() * 1E-9 - t_begin;

		// report the duration of every stage (difference between consecutive time stamps)
		double prev_finished = 0.0;
		for (int stage = 0; stage < stage_labels.length; stage++) {
			System.out.println(stage_labels[stage] + (finished[stage] - prev_finished) + " sec");
			prev_finished = finished[stage];
		}
		return new Matrix[] {x_jama, x_solve0, x_solve};
	}
	
	
	
}
