/**
 **
 ** QuaternionLma - Find quaternion to best transform a set of input 3D vectors
 **                into a set of output 3D vectors
 **
 ** Copyright (C) 2023 Elphel, Inc.
 **
 ** -----------------------------------------------------------------------------**
 **
 **  QuaternionLma.java is free software: you can redistribute it and/or modify
 **  it under the terms of the GNU General Public License as published by
 **  the Free Software Foundation, either version 3 of the License, or
 **  (at your option) any later version.
 **
 **  This program is distributed in the hope that it will be useful,
 **  but WITHOUT ANY WARRANTY; without even the implied warranty of
 **  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 **  GNU General Public License for more details.
 **
 **  You should have received a copy of the GNU General Public License
 **  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 ** -----------------------------------------------------------------------------**
 **
 */

package com.elphel.imagej.tileprocessor;

import java.util.concurrent.atomic.AtomicInteger;

import Jama.Matrix;

/**
 * Find the quaternion that best transforms a set of input 3D vectors into a set
 * of output 3D vectors, using the Levenberg-Marquardt algorithm (LMA).
 * <p>
 * For a quaternion {q0, q1, q2, q3} with vector part u = {q1, q2, q3}, each input
 * vector v is mapped as
 * {@code f(v) = (2*q0^2 - 1)*v + 2*(u.v)*u - 2*q0*(u x v)},
 * which for a unit quaternion equals {@code q* v q}. One extra (regularization)
 * sample pulls {@code q0^2 + q1^2 + q2^2 + q3^2} towards 1.0.
 * <p>
 * Typical use: {@link #prepareLMA}, then {@link #runLma}, then {@link #getQuaternion()}.
 * The LMA state is kept in instance fields without synchronization, so an instance
 * should not be shared between concurrent fits.
 */
public class QuaternionLma {
	private static final int  REGLEN =            1; // number of extra (regularization) samples
	private int               N =               0; // number of input/output vector pairs
	private double []         last_rms =        null; // {rms, rms_pure}, matching this.parameters_vector
	private double []         good_or_bad_rms = null; // just for diagnostics, to read last (failed) rms
	private double []         initial_rms =     null; // {rms, rms_pure}, first-calculated rms of the last runLma()
	private double []         parameters_vector = null; // current quaternion {q0, q1, q2, q3}
	private double []         x_vector =        null; // flattened input vectors, length 3 * N
	private double []         y_vector =        null; // flattened output vectors + regularization target (1.0), length 3 * N + REGLEN
	private double []         weights; // samples sum to pure_weight, regularization sample gets (1.0 - pure_weight)
	private double            pure_weight; // weight of samples only
	private double []         last_ymfx =       null; // weighted (y - f(x)) for parameters_vector
	private double [][]       last_jt =         null; // transposed Jacobian [parameters][3 * N + REGLEN]

	/**
	 * @return the current (best so far) quaternion {q0, q1, q2, q3}. This is the
	 *         internal array, not a copy.
	 */
	public double [] getQuaternion() {
		return parameters_vector;
	}

	/**
	 * @return {rms, rms_pure} for the current quaternion, or null before the first LMA step
	 */
	public double [] getLastRms() {
		return last_rms;
	}

	/**
	 * @return {rms, rms_pure} calculated at the start of the last runLma() call
	 */
	public double [] getInitialRms() {
		return initial_rms;
	}

	/**
	 * Prepare data for the LMA fit of f(vect_x[i]) ~= vect_y[i].
	 *
	 * @param vect_x      input vectors, [N][3]
	 * @param vect_y      output (target) vectors, [N][3]
	 * @param vect_w      per-component weights [N][3], or null for equal weights
	 * @param reg_w       regularization weight [0..1) of (q0^2 + q1^2 + q2^2 + q3^2 - 1);
	 *                    the vector samples share the remaining (1 - reg_w)
	 * @param quat0       initial quaternion {q0, q1, q2, q3} (copied)
	 * @param debug_level debug level (not used here)
	 */
	public void prepareLMA(
			double [][] vect_x,
			double [][] vect_y,
			double [][] vect_w,
			double      reg_w, // regularization weight [0..1) weight of q0^2+q1^2+q2^2+q3^2 -1
			double []   quat0,
			final int   debug_level) {
		N = vect_x.length;
		pure_weight = 1.0 - reg_w;
		x_vector = new double [3 * N];
		y_vector = new double [3 * N + REGLEN];
		weights =  new double [3 * N + REGLEN];
		parameters_vector = quat0.clone();
		double sw = 0;
		for (int i = 0; i < N; i++) {
			for (int j = 0; j < 3; j++) {
				x_vector[3 * i + j] = vect_x[i][j];
				y_vector[3 * i + j] = vect_y[i][j];
				double w = (vect_w != null)? vect_w[i][j] : 1.0;
				weights[3 * i + j] = w;
				sw += w;
			}
		}
		// Scale sample weights so they sum to pure_weight. With a zero total, the division
		// would give 0 * Infinity = NaN weights; keep the samples at zero weight instead.
		double k = (sw != 0.0) ? (pure_weight / sw) : 0.0;
		for (int i = 0; i < 3 * N; i++) {
			weights[i] *= k;
		}
		weights[3 * N] = 1.0 - pure_weight;
		y_vector[3 * N] = 1.0; // target for |q|^2
		last_jt = new double [parameters_vector.length][];
	}

	// TODO: Consider adding differences between x and y for regularization (or it won't work)
	// goal - to minimize "unneeded" rotation along the common axis
	/**
	 * Calculate the model values and, optionally, the transposed Jacobian.
	 * <p>
	 * For each input vector v = {x, y, z} and u = {q1, q2, q3}:
	 * {@code f(v) = (2*q0^2 - 1)*v + 2*(u.v)*u - 2*q0*(u x v)}.
	 * The last (regularization) element is {@code q0^2 + q1^2 + q2^2 + q3^2}.
	 *
	 * @param vector      quaternion {q0, q1, q2, q3}
	 * @param jt          null, or an array with vector.length rows. Each row is replaced
	 *                    with a new array of length 3 * N + REGLEN holding the derivatives.
	 * @param debug_level debug level (not used here)
	 * @return model values, length 3 * N + REGLEN
	 */
	private double [] getFxDerivs(
			double []         vector,
			final double [][] jt, // should be null or initialized with [vector.length][]
			final int         debug_level) {
		double [] fx = new double [weights.length];
		final double q0 = vector[0];
		final double q1 = vector[1];
		final double q2 = vector[2];
		final double q3 = vector[3];
		if (jt != null) {
			for (int i = 0; i < vector.length; i++) {
				jt[i] = new double [weights.length];
				jt[i][3 * N] = 2 * vector[i]; // d(|q|^2)/dq_i
			}
		}
		fx[3 * N] = q0*q0 + q1*q1 + q2*q2 + q3*q3;
		for (int i = 0; i < N; i++) {
			int i3 = 3 * i;
			final double x = x_vector[i3 + 0];
			final double y = x_vector[i3 + 1];
			final double z = x_vector[i3 + 2];
			final double s = q1 * x + q2 * y + q3 * z; // u.v
			fx[i3 + 0] = 2 * (q0 * (x * q0 - (q2 * z - q3 * y)) + s * q1) - x;
			fx[i3 + 1] = 2 * (q0 * (y * q0 - (q3 * x - q1 * z)) + s * q2) - y;
			fx[i3 + 2] = 2 * (q0 * (z * q0 - (q1 * y - q2 * x)) + s * q3) - z;
			if (jt != null) {
				// partial derivatives of fx[i3 + 0] by q0..q3
				jt[0][i3 + 0] = 4*x*q0 - 2*z*q2 + 2*y*q3;
				jt[1][i3 + 0] = 2*s +    2*q1*x;
				jt[2][i3 + 0] =-2*z*q0 + 2*q1*y; // d/dq2 of 2*(-q0*q2*z + s*q1) = 2*(q1*y - q0*z)
				jt[3][i3 + 0] = 2*y*q0 + 2*q1*z;

				// partial derivatives of fx[i3 + 1] by q0..q3
				jt[0][i3 + 1] = 4*y*q0 - 2*x*q3 + 2*z*q1;
				jt[1][i3 + 1] = 2*z*q0 + 2*x*q2;
				jt[2][i3 + 1] = 2*s +    2*y*q2;
				jt[3][i3 + 1] =-2*x*q0 + 2*z*q2;

				// partial derivatives of fx[i3 + 2] by q0..q3
				jt[0][i3 + 2] = 4*z*q0 - 2*y*q1 + 2*x*q2;
				jt[1][i3 + 2] =-2*y*q0 + 2*x*q3;
				jt[2][i3 + 2] = 2*x*q0 + 2*y*q3;
				jt[3][i3 + 2] = 2*s    + 2*z*q3;
			}
		}
		return fx;
	}

	/**
	 * Calculate weighted residuals weights * (y - fx) and, optionally, the RMS values.
	 * If a weighted residual is NaN, it is reported to stdout, and both it and its
	 * contribution to the RMS are set to zero.
	 *
	 * @param fx     model values from getFxDerivs(), length 3 * N + REGLEN
	 * @param rms_fp null, or double[2] to receive {rms, rms_pure}:
	 *               <ul>
	 *               <li>rms: over all samples, including regularization</li>
	 *               <li>rms_pure: over the vector samples only, normalized by pure_weight</li>
	 *               </ul>
	 * @return weighted residuals, same length as fx
	 */
	private double [] getYminusFxWeighted(
			final double []   fx,
			final double []   rms_fp // null or [2]
			) {
		final double []     wymfw =       new double [fx.length];
		double s_rms=0;
		double rms_pure=0;
		for (int i = 0; i < fx.length; i++) {
			double d = y_vector[i] - fx[i];
			double wd = d * weights[i];
			if (Double.isNaN(wd)) {
				System.out.println("getYminusFxWeighted(): weights["+i+"]="+weights[i]+", wd="+wd+
						", y_vector[i]="+y_vector[i]+", fx[i]="+fx[i]);
				wd = 0.0;
				d = 0.0;
			}
			if (i == (3 * N)) { // all vector samples accumulated, regularization not yet
				rms_pure = Math.sqrt(s_rms/pure_weight);
			}
			wymfw[i] = wd;
			s_rms += d * wd;
		}
		double rms = Math.sqrt(s_rms); // assuming sum_weights == 1.0;

		if (rms_fp != null) {
			rms_fp[0] = rms;
			rms_fp[1] = rms_pure;
		}

		return wymfw;
	}

	/**
	 * Calculate the damped, weighted normal matrix {@code JtWJ + lambda * diag(JtWJ)}.
	 * Matrix elements are distributed between threads. Only the upper triangle is
	 * computed, then mirrored.
	 *
	 * @param lambda LMA damping factor
	 * @param jt     transposed Jacobian [num_pars][num_points]
	 * @return symmetric [num_pars][num_pars] matrix
	 */
	private double [][] getWJtJlambda( // USED in lwir
			final double      lambda,
			final double [][] jt)
	{
		final int num_pars = jt.length;
		final int num_pars2 = num_pars * num_pars;
		final int num_points = jt[0].length;
		final double [][] wjtjl = new double [num_pars][num_pars];
		final Thread[] threads = ImageDtt.newThreadArray(ImageDtt.THREADS_MAX);
		final AtomicInteger ai = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					for (int indx = ai.getAndIncrement(); indx < num_pars2; indx = ai.getAndIncrement()) {
						int i = indx / num_pars;
						int j = indx % num_pars;
						if (j >= i) {
							double d = 0.0;
							for (int k = 0; k < num_points; k++) {
								d += weights[k]*jt[i][k]*jt[j][k];
							}
							wjtjl[i][j] = d;
							if (i == j) {
								wjtjl[i][j] += d * lambda;
							} else {
								wjtjl[j][i] = d;
							}
						}
					}
				}
			};
		}
		ImageDtt.startAndJoin(threads);
		return wjtjl;
	}

	/**
	 * Run LMA iterations, starting from the quaternion set by prepareLMA() or left by
	 * a previous run. The initial RMS is recalculated at the start of each run.
	 *
	 * @param lambda            initial damping factor (e.g. 0.1)
	 * @param lambda_scale_good lambda multiplier after an improving step (e.g. 0.5)
	 * @param lambda_scale_bad  lambda multiplier after a worsening step (e.g. 8.0)
	 * @param lambda_max        stop when lambda exceeds this value (e.g. 100)
	 * @param rms_diff          converged when the relative RMS improvement is below this (e.g. 0.001)
	 * @param num_iter          maximal number of iterations (e.g. 20)
	 * @param last_run          enables additional debug output (when debug_level > 0)
	 * @param debug_level       debug level
	 * @return iteration index (>= 0) if the last step improved or the result improved over
	 *         the initial RMS; -1 otherwise, including num_iter < 1
	 */
	public int runLma( // <0 - failed, >=0 iteration number
			double lambda,           // 0.1
			double lambda_scale_good,// 0.5
			double lambda_scale_bad, // 8.0
			double lambda_max,       // 100
			double rms_diff,         // 0.001
			int    num_iter,         // 20
			boolean last_run,
			int    debug_level) {
		if (num_iter < 1) {
			return -1; // nothing to do; avoids dereferencing last_rms that was never calculated
		}
		boolean [] rslt = {false,false};
		this.last_rms = null; // forces lmaStep() to (re)calculate the initial state
		int iter = 0;
		for (iter = 0; iter < num_iter; iter++) {
			rslt =  lmaStep(
					lambda,
					rms_diff,
					debug_level);
			if (rslt == null) {
				return -1; // false; // need to check
			}
			if (debug_level > 1) {
				System.out.println("LMA step "+iter+": {"+rslt[0]+","+rslt[1]+"} full RMS= "+good_or_bad_rms[0]+
						" ("+initial_rms[0]+"), pure RMS="+good_or_bad_rms[1]+" ("+initial_rms[1]+") + lambda="+lambda);
			}
			if (rslt[1]) { // converged, or singular matrix
				break;
			}
			if (rslt[0]) { // good
				lambda *= lambda_scale_good;
			} else {
				lambda *= lambda_scale_bad;
				if (lambda > lambda_max) {
					break; // not used in lwir
				}
			}
		}
		if (rslt[0]) { // better
			if (iter >= num_iter) { // better, but num tries exceeded
				if (debug_level > 1) System.out.println("Step "+iter+": Improved, but number of steps exceeded maximal");
			} else {
				if (debug_level > 1) System.out.println("Step "+iter+": LMA: Success");
			}

		} else { // improved over initial ?
			if (last_rms[0] < initial_rms[0]) { // false for NaN
				rslt[0] = true;
				if (debug_level > 1) System.out.println("Step "+iter+": Failed to converge, but result improved over initial");
			} else {
				if (debug_level > 1) System.out.println("Step "+iter+": Failed to converge");
			}
		}
		boolean show_intermediate = true;
		if (show_intermediate && (debug_level > 0)) {
			System.out.println("LMA: full RMS="+last_rms[0]+" ("+initial_rms[0]+"), pure RMS="+last_rms[1]+" ("+initial_rms[1]+") + lambda="+lambda);
		}
		if (debug_level > 2){
			String [] lines1 = printOldNew(false); // boolean allvectors)
			System.out.println("iteration="+iter);
			for (String line : lines1) {
				System.out.println(line);
			}
		}
		if (debug_level > 0) {
			if ((debug_level > 1) ||  last_run) { // (iter == 1) || last_run) {
				if (!show_intermediate) {
					System.out.println("LMA: iter="+iter+",   full RMS="+last_rms[0]+" ("+initial_rms[0]+"), pure RMS="+last_rms[1]+" ("+initial_rms[1]+") + lambda="+lambda);
				}
				String [] lines = printOldNew(false); // boolean allvectors)
				for (String line : lines) {
					System.out.println(line);
				}
			}
		}
		if ((debug_level > -2) && !rslt[0]) { // failed
			if ((debug_level > 1) || (iter == 1) || last_run) {
				System.out.println("LMA failed on iteration = "+iter);
				String [] lines = printOldNew(true); // boolean allvectors)
				for (String line : lines) {
					System.out.println(line);
				}
			}
			System.out.println();
		}

		return rslt[0]? iter : -1;
	}

	/**
	 * Perform a single LMA step. On improvement the new quaternion is accepted;
	 * otherwise the state (residuals, Jacobian, RMS) is restored for the current one.
	 *
	 * @param lambda      damping factor
	 * @param rms_diff    relative RMS improvement below which the fit is considered converged
	 * @param debug_level debug level
	 * @return {improved, done}. done is set on convergence, or on a singular
	 *         normal matrix (with improved == false). May return null if residuals
	 *         are unavailable.
	 */
	private boolean [] lmaStep(
			double lambda,
			double rms_diff,
			int debug_level) {
		boolean [] rslt = {false,false};
		// prepareLMA() only loads data; the initial residuals and Jacobian are calculated here
		if (this.last_rms == null) { //first time, need to calculate all (vector is valid)
			last_rms = new double[2];
			if (debug_level > 1) {
				System.out.println("lmaStep(): first step");
			}
			double [] fx = getFxDerivs(
					parameters_vector, // double []         vector,
					last_jt,           // final double [][] jt, // should be null or initialized with [vector.length][]
					debug_level);      // final int         debug_level)
			last_ymfx = getYminusFxWeighted(
					fx, // final double []   fx,
					last_rms); // final double []   rms_fp // null or [2]
			this.initial_rms = this.last_rms.clone();
			this.good_or_bad_rms = this.last_rms.clone();
			// TODO: restore diagnostics (dbgYminusFxWeight(), dbgJacobians())
			if (last_ymfx == null) {
				return null; // need to re-init/restart LMA
			}
		}
		Matrix y_minus_fx_weighted = new Matrix(this.last_ymfx, this.last_ymfx.length);

		Matrix wjtjlambda = new Matrix(getWJtJlambda(
				lambda,
				this.last_jt)); // double [][] jt)

		if (debug_level>2) {
			System.out.println("JtJ + lambda*diag(JtJ");
			wjtjlambda.print(18, 6);
		}
		Matrix jtjl_inv = null;
		try {
			jtjl_inv = wjtjlambda.inverse(); // check for errors
		} catch (RuntimeException e) {
			rslt[1] = true;
			if (debug_level > 0) {
				System.out.println("Singular Matrix!");
			}

			return rslt;
		}
		if (debug_level>2) {
			System.out.println("(JtJ + lambda*diag(JtJ).inv()");
			jtjl_inv.print(18, 6);
		}
		// NOTE: original debug note - last_jt was seen with NaNs (NaN inputs propagate into the Jacobian)
		Matrix jty = (new Matrix(this.last_jt)).times(y_minus_fx_weighted);
		if (debug_level>2) {
			System.out.println("Jt * (y-fx)");
			jty.print(18, 6);
		}

		Matrix mdelta = jtjl_inv.times(jty);
		if (debug_level>2) {
			System.out.println("mdelta");
			mdelta.print(18, 6);
		}

		double []  delta =      mdelta.getColumnPackedCopy();
		double []  new_vector = parameters_vector.clone();
		for (int i = 0; i < parameters_vector.length; i++) {
			new_vector[i] += delta[i];
		}

		// this also overwrites last_jt with the Jacobian at new_vector
		double [] fx = getFxDerivs(
				new_vector, // double []         vector,
				last_jt,           // final double [][] jt, // should be null or initialized with [vector.length][]
				debug_level);      // final int         debug_level)
		double [] rms = new double[2];
		last_ymfx = getYminusFxWeighted(
				fx, // final double []   fx,
				rms); // final double []   rms_fp // null or [2]

		if (last_ymfx == null) {
			return null; // need to re-init/restart LMA
		}

		this.good_or_bad_rms = rms.clone();
		if (rms[0] < this.last_rms[0]) { // improved
			rslt[0] = true;
			rslt[1] = rms[0] >=(this.last_rms[0] * (1.0 - rms_diff));
			this.last_rms = rms.clone();
			this.parameters_vector = new_vector.clone();
		} else { // worsened
			rslt[0] = false;
			rslt[1] = false; // do not know, caller will decide
			// restore state for the current (unchanged) parameters_vector
			fx = getFxDerivs(
					parameters_vector, // double []         vector,
					last_jt,           // final double [][] jt, // should be null or initialized with [vector.length][]
					debug_level);      // final int         debug_level)
			last_ymfx = getYminusFxWeighted(
					fx, // final double []   fx,
					this.last_rms); // final double []   rms_fp // null or [2]
			if (last_ymfx == null) {
				return null; // need to re-init/restart LMA
			}
		}
		return rslt;
	}

	/**
	 * Format old/new parameter values for debug output.
	 * TODO: implement; currently returns an empty array.
	 *
	 * @param allvectors intended to select printing of all vectors (unused for now)
	 * @return lines to print (currently none)
	 */
	public String [] printOldNew(boolean allvectors) {
		return new String[] {};
	}

}
