package com.elphel.imagej.tileprocessor;

import org.apache.commons.math3.geometry.euclidean.threed.Rotation;
import org.apache.commons.math3.geometry.euclidean.threed.RotationOrder;

import Jama.Matrix;

/**
 * Levenberg-Marquardt (LMA) fit of a single quaternion Q that, applied to a set of poses X
 * (position {x,y,z} and orientation quaternion per scene), best reproduces the matching poses Y.
 * <p>
 * Typical use: {@link #prepareLMA}, then {@link #runLma}, then {@link #getQuaternion()} and the
 * RMS getters. Each scene contributes 7 samples {x,y,z,q0,q1,q2,q3}; optional regularization
 * samples that pull q1..q3 of the fitted quaternion to zero are appended after the scene samples.
 */
public class XyzQLma {
	/** Number of extra (regularization) samples appended when regularization is enabled (targets for q1..q3). */
	private final static int  REGLEN =            3;

	private int               N =               0;    // number of scenes
	private int               samples =         3;    // number of y/fx values per scene (set to 7 in prepareLMA)
	private int               samples_x =       3;    // number of x values per scene (set to 7 in prepareLMA)
	private double []         last_rms =        null; // {rms, rms_pure, rms_pure_xyz}, matching this.parameters_vector
	private double []         good_or_bad_rms = null; // just for diagnostics, to read last (possibly failed) rms
	private double []         initial_rms =     null; // {rms, rms_pure, rms_pure_xyz}, first-calculated rms
	private double []         parameters_vector = null; // current quaternion {q0,q1,q2,q3}
	private double []         x_vector =        null;
	private double []         y_vector =        null;
	private double []         y_inv_vector =    null; // not used by the code in this class
	private double []         weights;     // normalized so sum is 1.0 for all - samples and extra regularization
	private double            pure_weight; // weight of samples only (1.0 - regularization weight)
	private double            xyz_weight;  // weight of all xyz samples (weight of rotations is pure_weight - xyz_weight)
	private double []         last_ymfx =       null; // weighted (y - fx) matching last_jt
	private double [][]       last_jt =         null; // transposed Jacobian [parameter][sample]
	private double []         fx_saved =        null;

	/**
	 * Build the x, y and weight vectors for the fit and reset the fitted quaternion to identity {1,0,0,0}.
	 * Scenes where either vect_x[i] or vect_y[i] is null get zero weight. Orientation angles are
	 * converted to quaternions with Rotation(RotationOrder.YXZ, ErsCorrection.ROT_CONV, ...).
	 * <p>
	 * NOTE(review): quaternion targets are compared component-wise, and q and -q encode the same
	 * rotation; presumably inputs are expected to stay in one hemisphere - confirm against callers.
	 *
	 * @param vect_x      per-scene source poses, []{{x,y,z},{a,t,r}}
	 * @param vect_y      per-scene target poses, []{{x,y,z},{a,t,r}}
	 * @param vect_w      per-scene weights or null (all scenes weighted 1.0)
	 * @param transl_cost share of each scene weight given to each translation sample, 0.0 ... 1.0;
	 *                    each rotation sample gets (1.0 - transl_cost)
	 * @param reg_w       total weight [0..1) of the regularization samples that pull q1..q3 to 0;
	 *                    0 disables regularization
	 * @param debug_level if positive, passes the y and x vectors to debugYfX()
	 */
	public void prepareLMA(
			double [][][] vect_x, // []{{x,y,z},{a,t,r}}
			double [][][] vect_y, // []{{x,y,z},{a,t,r}}
			double []     vect_w, // one per scene
			double        transl_cost, // 0.0 ... 1.0;
			double        reg_w,      // regularization weight [0..1), pulls q1..q3 to 0
			final int     debug_level) {
		double []     quat0 = {1,0,0,0};
		N = vect_x.length;
		samples =   7; // x,y,z,q0,q1,q2,q3 - getFxDerivs() always produces 7 values per scene
		samples_x = 7;
		pure_weight = 1.0 - reg_w;
		int extra_samples = (reg_w > 0) ? REGLEN : 0; // regularize q1..q3, do not pull q0
		x_vector = new double [samples_x * N];
		y_vector = new double [samples *   N + extra_samples];
		weights =  new double [samples *   N + extra_samples];
		parameters_vector = quat0.clone();
		double sw = 0;
		xyz_weight = 0;
		for (int i = 0; i < N; i++) {
			if ((vect_x[i] == null) || (vect_y[i] == null)) {
				continue; // missing scene: freshly allocated x, y and weights already hold 0.0
			}
			double sample_weight = (vect_w != null) ? vect_w[i] : 1.0;
			double tw = sample_weight * transl_cost;
			double rw = sample_weight * (1.0 - transl_cost);
			Rotation   rot_x = new Rotation(RotationOrder.YXZ, ErsCorrection.ROT_CONV,
					vect_x[i][1][0], vect_x[i][1][1], vect_x[i][1][2]);
			Rotation   rot_y = new Rotation(RotationOrder.YXZ, ErsCorrection.ROT_CONV,
					vect_y[i][1][0], vect_y[i][1][1], vect_y[i][1][2]);
			// Translation components
			for (int j = 0; j < 3; j++) {
				x_vector[samples_x * i + j] = vect_x[i][0][j];
				y_vector[samples * i +   j] = vect_y[i][0][j];
				weights[samples * i +    j] = tw;
				sw += tw;
				xyz_weight += tw;
			}
			// Rotation components (including q0)
			x_vector[samples_x * i + 3] = rot_x.getQ0();
			x_vector[samples_x * i + 4] = rot_x.getQ1();
			x_vector[samples_x * i + 5] = rot_x.getQ2();
			x_vector[samples_x * i + 6] = rot_x.getQ3();
			y_vector[samples * i + 3] = rot_y.getQ0();
			y_vector[samples * i + 4] = rot_y.getQ1();
			y_vector[samples * i + 5] = rot_y.getQ2();
			y_vector[samples * i + 6] = rot_y.getQ3();
			for (int j = 0; j < 4; j++) {
				weights[samples * i + 3 + j] = rw;
				sw += rw;
			}
		}
		// Normalize scene weights so they sum to pure_weight. With no usable scenes (sw == 0)
		// keep all of them at zero instead of producing NaN/Infinity.
		double k = (sw > 0.0) ? (pure_weight / sw) : 0.0;
		for (int i = 0; i < weights.length; i++) weights[i] *= k;
		xyz_weight *= k;
		if (extra_samples > 0) {
			double w = (1.0 - pure_weight)/extra_samples;
			for (int i = 0; i < extra_samples; i++) {
				weights [samples * N + i] = w;
				y_vector[samples * N + i] = 0.0; // target for q1..q3 is 0
			}
		}
		last_jt = new double [parameters_vector.length][];
		if (debug_level > 0) {
			debugYfX ( "",   // String pfx,
					y_vector); // double [] data)
			debugYfX ( "PIMU-",   // String pfx,
					x_vector); // double [] data)
		}
		return;
	}

	/**
	 * Debug hook for dumping a data vector; currently does nothing.
	 * @param pfx  prefix for the output
	 * @param data vector to output
	 */
	private void debugYfX (
			String pfx,
			double [] data) {
	}

	/**
	 * @return current quaternion {q0,q1,q2,q3} (the internal array, not a copy)
	 */
	public double [] getQuaternion() {
		return parameters_vector;
	}

	/**
	 * Extend {rms, rms_pure, xyz_rms} with the RMS of the rotational samples.
	 * Weighted squared sums are additive, so the rotational part is derived in quadrature:
	 * atr_rms^2 * (pure_weight - xyz_weight) = rms_pure^2 * pure_weight - xyz_rms^2 * xyz_weight.
	 * @param rms3 {rms, rms_pure, xyz_rms} or null (nothing calculated yet)
	 * @return {rms, rms_pure, xyz_rms, atr_rms}, unavailable elements are NaN
	 */
	private double [] getRms4(double [] rms3) { // rms, rms_pure, xyz_rms, atr_rms
		double [] rms4 = new double [4];
		for (int i = 0; i < rms4.length; i++) rms4[i] = Double.NaN;
		if (rms3 == null) {
			return rms4; // no fit was run yet
		}
		System.arraycopy(rms3, 0, rms4, 0, Math.min(rms3.length, 3));
		if ((rms3.length > 2) && (pure_weight > xyz_weight)) {
			double s_pure = rms3[1] * rms3[1] * pure_weight;
			// with zero xyz weight xyz_rms is undefined (NaN) and contributes nothing
			double s_xyz = (xyz_weight > 0.0) ? (rms3[2] * rms3[2] * xyz_weight) : 0.0;
			// clamp: rounding may make the difference slightly negative
			rms4[3] = Math.sqrt(Math.max(0.0, s_pure - s_xyz) / (pure_weight - xyz_weight));
		}
		return rms4;
	}

	/**
	 * @return {rms, rms_pure, xyz_rms, atr_rms} for the current quaternion, all NaN before the first fit
	 */
	public double [] getLastRms() {
		return getRms4(last_rms);
	}

	/**
	 * @return {rms, rms_pure, xyz_rms, atr_rms} calculated at the first LMA step, all NaN before the first fit
	 */
	public double [] getInitialRms() {
		return getRms4(initial_rms);
	}

	/**
	 * @return {rms, rms_pure, xyz_rms, atr_rms} of the last tried step (accepted or not), all NaN before the first fit
	 */
	public double [] getGoodOrBadRms() {
		return getRms4(good_or_bad_rms);
	}

	/**
	 * @return fx calculated for the current quaternion (debug level -3)
	 */
	public double [] getLastFx() {
		return getFxDerivs(
				parameters_vector, // double []         vector,
				null,              // final double [][] jt
				-3);               // final int         debug_level
	}

	/**
	 * @param debug_level debug level passed to getFxDerivs()
	 * @return fx calculated for the current quaternion
	 */
	public double [] getLastFx(int debug_level) {
		return getFxDerivs(
				parameters_vector, // double []         vector,
				null,              // final double [][] jt
				debug_level);      // final int         debug_level
	}

	/**
	 * Calculate fx for the current quaternion and keep it for getSavedFx().
	 * @param debug_level debug level passed to getFxDerivs()
	 */
	public void saveFx(int debug_level) {
		fx_saved = getFxDerivs(
				parameters_vector, // double []         vector,
				null,              // final double [][] jt
				debug_level);      // final int         debug_level
	}

	/**
	 * @return fx stored by saveFx() or null if it was never called
	 */
	public double [] getSavedFx() {
		return fx_saved;
	}

	/** @return the x vector, 7 values {x,y,z,q0,q1,q2,q3} per scene (the internal array) */
	public double [] getX() {
		return x_vector;
	}

	/** @return the y vector, 7 values per scene followed by regularization targets (the internal array) */
	public double [] getY() {
		return y_vector;
	}

	/** @return the weights vector matching getY() (the internal array) */
	public double [] getW() {
		return weights;
	}

	/**
	 * Calculate weighted residuals weights[i] * (y[i] - fx[i]) and, optionally, RMS values.
	 * @param fx     function values, same layout as y_vector
	 * @param rms_fp null or an array [3] that receives {rms, rms_pure, rms_pure_xyz}: RMS of
	 *               all samples, of the scene samples only and of the scene xyz samples only
	 * @param noNaNs if true, any NaN residual makes the method return null (rms_fp is filled
	 *               with NaN); if false, NaN residuals are treated as 0
	 * @return weighted residuals or null
	 */
	private double [] getYminusFxWeighted(
			final double [] fx,
			final double [] rms_fp, // null or [3]
			boolean noNaNs) {
		final double [] wymfw = new double[fx.length];
		final int pure_len = samples * N; // regularization samples start here
		double s_rms = 0.0;
		double sxyz_rms = 0.0;
		double rms_pure = Double.NaN;
		double rms_pure_xyz = Double.NaN;
		for (int i = 0; i < fx.length; i++) {
			double d = y_vector[i] - fx[i];
			double wd = d * weights[i];
			if (Double.isNaN(wd)) {
				if (noNaNs) {
					if (rms_fp != null) {
						rms_fp[0] = Double.NaN;
						rms_fp[1] = Double.NaN;
						rms_fp[2] = Double.NaN;
					}
					return null;
				}
				wd = 0.0;
				d = 0.0;
			}
			if (i == pure_len) { // all scene samples are accumulated, snapshot "pure" values
				if (pure_weight > 0.0) {
					rms_pure = Math.sqrt(s_rms / pure_weight);
				}
				if (xyz_weight > 0.0) {
					rms_pure_xyz = Math.sqrt(sxyz_rms / xyz_weight);
				}
			}
			wymfw[i] = wd;
			double wd2 = d * wd;
			s_rms += wd2;
			if ((i < pure_len) && ((i % samples) < 3)) { // x,y,z of a scene
				sxyz_rms += wd2;
			}
		}
		double rms = Math.sqrt(s_rms); // assuming sum(weights) == 1.0
		if (Double.isNaN(rms_pure)) rms_pure = rms; // no regularization samples
		if (Double.isNaN(rms_pure_xyz)) {
			rms_pure_xyz = (xyz_weight > 0.0) ? Math.sqrt(sxyz_rms / xyz_weight) : Double.NaN;
		}
		if (rms_fp != null) {
			rms_fp[0] = rms;
			rms_fp[1] = rms_pure;
			rms_fp[2] = rms_pure_xyz;
		}
		return wymfw;
	}

	/**
	 * Calculate the weighted normal matrix Jt*W*J with the diagonal multiplied by (1 + lambda).
	 * @param lambda damping factor
	 * @param jt     transposed Jacobian [parameter][sample]
	 * @return symmetric [parameters][parameters] matrix
	 */
	private double [][] getWJtJlambda(
			final double lambda,
			final double [][] jt) {
		final int num_pars = jt.length;
		final int nup_points = jt[0].length;
		final double [][] wjtjl = new double[num_pars][num_pars];
		for (int i = 0; i < num_pars; i++) {
			for (int j = i; j < num_pars; j++) {
				double d = 0.0;
				for (int k = 0; k < nup_points; k++) {
					d += weights[k] * jt[i][k] * jt[j][k];
				}
				wjtjl[i][j] = d;
				if (i == j) {
					wjtjl[i][j] += d * lambda;
				} else {
					wjtjl[j][i] = d;
				}
			}
		}
		return wjtjl;
	}

	/**
	 * Make the normal matrix invertible along the quaternion direction.
	 * The Jacobian from applyQCore() is taken through the normalization q = Q/|Q|, so changing Q
	 * along itself does not change fx and the corresponding direction of Jt*W*J is (near) zero.
	 * For the initial Q = {1,0,0,0} row/column 0 is exactly zero and multiplicative damping
	 * cannot fix that. Adding scale * q * qT only constrains that radial direction, which is
	 * removed by normalizeQ() after the step anyway.
	 * @param wjtjl normal matrix [4][4], modified in place
	 * @param q     current quaternion
	 */
	private static void addGaugeTerm(
			final double [][] wjtjl,
			final double []   q) {
		final double [] qn = normalizeQ(q);
		double trace = 0.0;
		for (int i = 0; i < wjtjl.length; i++) {
			trace += wjtjl[i][i];
		}
		// keep the added term on the same scale as the rest of the matrix
		final double scale = (trace > 0.0) ? (trace / wjtjl.length) : 1.0;
		for (int i = 0; i < wjtjl.length; i++) {
			for (int j = 0; j < wjtjl.length; j++) {
				wjtjl[i][j] += scale * qn[i] * qn[j];
			}
		}
	}

	/**
	 * Run a single LMA step: solve for the parameter increment, normalize the new quaternion
	 * and accept it only if the RMS improved.
	 * @param lambda      damping factor
	 * @param rms_diff    relative RMS improvement below which the fit is considered finished
	 * @param debug_level debug level
	 * @return {improved, finished} or null if fx contained NaN. "finished" is also set
	 *         (with improved == false) when the normal matrix could not be inverted.
	 */
	private boolean [] lmaStep(
			double lambda,
			double rms_diff,
			int debug_level) {
		boolean noNaNs = true;
		boolean [] rslt = {false, false};
		if (this.last_rms == null) { // first call
			last_rms = new double[3];
			double [] fx0 = getFxDerivs(
					parameters_vector, // double []         vector,
					last_jt,           // final double [][] jt
					debug_level);      // final int         debug_level
			last_ymfx = getYminusFxWeighted(
					fx0,      // final double [] fx
					last_rms, // final double [] rms_fp
					noNaNs);  // boolean noNaNs
			this.initial_rms = this.last_rms.clone();
			this.good_or_bad_rms = this.last_rms.clone();
			if (last_ymfx == null) {
				return null;
			}
		}
		Matrix y_minus_fx_weighted = new Matrix(this.last_ymfx, this.last_ymfx.length);
		double [][] wjtjl = getWJtJlambda(
				lambda,       // double lambda
				this.last_jt  // double[][] jt
		);
		addGaugeTerm(
				wjtjl,              // final double [][] wjtjl
				parameters_vector); // final double []   q
		Matrix wjtjlambda = new Matrix(wjtjl);
		Matrix jtjl_inv;
		try {
			jtjl_inv = wjtjlambda.inverse();
		} catch (RuntimeException e) { // singular matrix: nothing more can be done
			rslt[1] = true;
			return rslt;
		}
		Matrix jty = (new Matrix(this.last_jt)).times(y_minus_fx_weighted);
		Matrix mdelta = jtjl_inv.times(jty);
		double [] delta = mdelta.getColumnPackedCopy();
		double [] new_vector = parameters_vector.clone();
		for (int i = 0; i < parameters_vector.length; i++) {
			new_vector[i] += delta[i];
		}
		new_vector = normalizeQ(new_vector); // keep parameterized quaternion normalized

		double [] fx = getFxDerivs(
				new_vector, // double []         vector,
				last_jt,    // final double [][] jt
				debug_level // final int         debug_level
		);
		double [] rms = new double[3];
		last_ymfx = getYminusFxWeighted(
				fx,       // final double [] fx
				rms,      // final double [] rms_fp
				noNaNs);  // boolean noNaNs
		this.good_or_bad_rms = rms.clone();
		if ((rms[0] < this.last_rms[0]) && (last_ymfx != null)) { // improved
			rslt[0] = true;
			rslt[1] = rms[0] >= (this.last_rms[0] * (1.0 - rms_diff));
			this.last_rms = rms.clone();
			this.parameters_vector = new_vector.clone();
		} else { // worsened (or NaN), restore state for the unchanged parameters_vector
			rslt[0] = false;
			rslt[1] = false;
			fx = getFxDerivs(
					parameters_vector, // double []         vector,
					last_jt,           // final double [][] jt
					debug_level);      // final int         debug_level
			last_ymfx = getYminusFxWeighted(
					fx,            // final double [] fx
					this.last_rms, // final double [] rms_fp
					noNaNs);       // boolean noNaNs
			if (last_ymfx == null) {
				return null;
			}
		}
		return rslt;
	}

	/**
	 * Run LMA iterations starting from the current quaternion.
	 * @param lambda            initial damping factor
	 * @param lambda_scale_good lambda multiplier after an improved step
	 * @param lambda_scale_bad  lambda multiplier after a failed step
	 * @param lambda_max        give up when lambda exceeds this value after a failed step
	 * @param rms_diff          relative RMS improvement below which the fit is finished
	 * @param num_iter          maximal number of iterations
	 * @param last_run          not used by this implementation
	 * @param debug_level       debug level
	 * @return negative on failure (NaN or no improvement over the initial RMS), otherwise
	 *         the iteration index where the loop stopped (num_iter if it ran to the end)
	 */
	public int runLma( // <0 failed, >=0 last iteration index
			double lambda,
			double lambda_scale_good,
			double lambda_scale_bad,
			double lambda_max,
			double rms_diff,
			int    num_iter,
			boolean last_run,
			int    debug_level) {
		boolean [] rslt = {false, false};
		this.last_rms = null; // forces lmaStep() to recalculate the initial state
		int iter = 0;
		for (iter = 0; iter < num_iter; iter++) {
			rslt = lmaStep(
					lambda,     // double lambda
					rms_diff,   // double rms_diff
					debug_level // int debug_level
			);
			if (rslt == null) {
				return -1;
			}
			if (rslt[1]) {
				break;
			}
			if (rslt[0]) {
				lambda *= lambda_scale_good;
			} else {
				lambda *= lambda_scale_bad;
				if (lambda > lambda_max) {
					break;
				}
			}
		}
		if (!rslt[0]) {
			// the last step failed, but earlier steps may have improved the result
			if ((last_rms != null) && (initial_rms != null) && (last_rms[0] < initial_rms[0])) {
				rslt[0] = true;
			}
		}
		return rslt[0] ? iter : -1;
	}

	/**
	 * Calculate function values (and optionally the transposed Jacobian) for all samples.
	 * For each scene the quaternion is applied to {x,y,z,q0,q1,q2,q3} from x_vector with
	 * applyQCore(); the regularization samples (if any) are q1..q3 of the vector itself.
	 * @param vector      quaternion {q0,q1,q2,q3}
	 * @param jt          null or an array [vector.length][] to receive the transposed Jacobian
	 * @param debug_level debug level (not used here)
	 * @return fx, same layout and length as the weights
	 */
	private double [] getFxDerivs(
			double []         vector,
			final double [][] jt, // should be null or initialized with [vector.length][]
			final int         debug_level) {
		double [] fx = new double [weights.length];
		if (jt != null) {
			for (int i = 0; i < vector.length; i++) {
				jt[i] = new double [weights.length];
			}
		}
		for (int i = 0; i < N; i++) {
			int bix = samples_x * i;
			int bof = samples   * i;
			double [] xyzQ = new double[7];
			System.arraycopy(x_vector, bix, xyzQ, 0, 7);
			double [] f7 = new double[7];
			if (jt != null) {
				double [][] jt7 = new double[4][7];
				applyQCore(xyzQ, vector, f7, jt7);
				for (int p = 0; p < vector.length; p++) {
					System.arraycopy(jt7[p], 0, jt[p], bof, 7);
				}
			} else {
				applyQCore(xyzQ, vector, f7, null);
			}
			System.arraycopy(f7, 0, fx, bof, 7);
		}
		// Optional regularization columns at the end:
		// 3 extras -> pull q1..q3 to 0; if 4 extras, keep first one (q0) at 0 and pull q1..q3.
		int extra = weights.length - samples * N;
		int base = samples * N;
		if (extra >= 3) {
			int start = (extra == 4) ? 1 : 0;
			for (int j = 1; j <= 3; j++) {
				int col = base + start + (j - 1);
				if (col < fx.length) {
					fx[col] = vector[j];
					if (jt != null) {
						jt[j][col] = 1.0;
					}
				}
			}
		}
		return fx;
	}


	/**
	 * Apply unit quaternion Q[4] (rotation around global {0,0,0}) to a body at {x,y,z} in global frame
	 * and orientation of its frame represented by the unit quaternion {q0,q1,q2,q3}
	 * @param xyzQ array {x,y,z,q0,q1,q2,q3}
	 * @param Q array Q[4]
	 * @return {x',y',z',q0',q1',q2',q3'} consisting of new global coordinates x',y',z' and orientation
	 * as unit quaternion {q0',q1',q2',q3'}
	 */
	private static double [] applyQ(
			double [] xyzQ,
			double [] Q) {
		double [] result = new double[7];
		applyQCore(xyzQ, Q, result, null);
		return result;
	}
	/**
	 * Calculate transposed Jacobian [4][7] for the method applyQ(), assuming normalization (unity length) of Q,
	 * so partial derivative by each component of Q indirectly changes all other components to keep sum of squares equal 1
	 * @param xyzQ array {x,y,z,q0,q1,q2,q3}
	 * @param Q array Q[4]
	 * @return [4][7] array of partial derivatives with respect to the components of Q
	 */
	private static double [][] getJT(
			double [] xyzQ,
			double [] Q) {
		double [][] jt = new double [4][7];
		applyQCore(xyzQ, Q, null, jt);
		return jt;
	}
	/**
	 * Numerical (central difference) version of getJT(): each component of Q is changed by
	 * +/- delta/2 and the result is re-normalized before applying.
	 * @param xyzQ  array {x,y,z,q0,q1,q2,q3}
	 * @param Q     array Q[4]
	 * @param delta full difference step
	 * @return [4][7] array of partial derivatives with respect to the components of Q
	 */
	private static double [][] getJTdelta(
			double [] xyzQ,
			double [] Q,
			double    delta) {
		double [][] jt = new double [4][7];
		for (int i = 0; i < 4; i++) {
			double [] qp = Q.clone();
			double [] qm = Q.clone();
			qp[i] += 0.5*delta;
			qm[i] -= 0.5*delta;
			qp = normalizeQ(qp);
			qm = normalizeQ(qm);
			double [] fp = applyQ(xyzQ, qp);
			double [] fm = applyQ(xyzQ, qm);
			for (int j = 0; j < 7; j++) {
				jt[i][j] = (fp[j] - fm[j]) / delta;
			}
		}
		return jt;
	}

	/**
	 * Scale a quaternion to unit length.
	 * @param q quaternion {q0,q1,q2,q3}
	 * @return new normalized quaternion, identity {1,0,0,0} for an all-zero input
	 */
	private static double [] normalizeQ(double [] q) {
		double s2 = q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3];
		if (s2 <= 0.0) {
			return new double[] {1.0, 0.0, 0.0, 0.0};
		}
		double k = 1.0 / Math.sqrt(s2);
		return new double[] {k * q[0], k * q[1], k * q[2], k * q[3]};
	}

	/**
	 * @param q quaternion {q0,q1,q2,q3}
	 * @return conjugate quaternion {q0,-q1,-q2,-q3}
	 */
	private static double [] conjQ(double [] q) {
		return new double[] {q[0], -q[1], -q[2], -q[3]};
	}

	/**
	 * Hamilton product of two quaternions.
	 * @param a left quaternion
	 * @param b right quaternion
	 * @return a * b
	 */
	private static double [] mulQ(double [] a, double [] b) {
		double a0 = a[0];
		double a1 = a[1];
		double a2 = a[2];
		double a3 = a[3];
		double b0 = b[0];
		double b1 = b[1];
		double b2 = b[2];
		double b3 = b[3];
		return new double[] {
				a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3,
				a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2,
				a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1,
				a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0
		};
	}

	/**
	 * Rotate a 3D vector as q * {0,v} * conj(q).
	 * @param v vector {x,y,z}
	 * @param q quaternion {q0,q1,q2,q3}
	 * @return rotated vector {x',y',z'}
	 */
	private static double [] rotateVecByQ(double [] v, double [] q) {
		double [] vq = new double[] {0.0, v[0], v[1], v[2]};
		double [] r = mulQ(mulQ(q, vq), conjQ(q));
		return new double[] {r[1], r[2], r[3]};
	}

	/**
	 * Common implementation of applyQ() and getJT().
	 * The outputs use the components of Q as given (they are not normalized here), while
	 * the Jacobian additionally applies the chain rule for q = Q/|Q|; the two are consistent
	 * when Q has unit length.
	 * @param xyzQ array {x,y,z,q0,q1,q2,q3}
	 * @param Q    array Q[4]
	 * @param out  null or array [7] to receive {x',y',z',q0',q1',q2',q3'}
	 * @param jt   null or array [4][7] to receive the transposed Jacobian (all zeros if Q is zero)
	 */
	private static void applyQCore(
			double [] xyzQ,
			double [] Q,
			double [] out,      // nullable, [7]
			double [][] jt) {   // nullable, [4][7]
		double q0 = Q[0];
		double q1 = Q[1];
		double q2 = Q[2];
		double q3 = Q[3];
		double x = xyzQ[0];
		double y = xyzQ[1];
		double z = xyzQ[2];
		double r0 = xyzQ[3];
		double r1 = xyzQ[4];
		double r2 = xyzQ[5];
		double r3 = xyzQ[6];

		if (out != null) {
			// rotated position
			out[0] = (q0 * q0 + q1 * q1 - q2 * q2 - q3 * q3) * x + 2.0 * (q1 * q2 - q0 * q3) * y + 2.0 * (q1 * q3 + q0 * q2) * z;
			out[1] = 2.0 * (q1 * q2 + q0 * q3) * x + (q0 * q0 - q1 * q1 + q2 * q2 - q3 * q3) * y + 2.0 * (q2 * q3 - q0 * q1) * z;
			out[2] = 2.0 * (q1 * q3 - q0 * q2) * x + 2.0 * (q2 * q3 + q0 * q1) * y + (q0 * q0 - q1 * q1 - q2 * q2 + q3 * q3) * z;
			// composed orientation Q * r
			out[3] = q0 * r0 - q1 * r1 - q2 * r2 - q3 * r3;
			out[4] = q0 * r1 + q1 * r0 + q2 * r3 - q3 * r2;
			out[5] = q0 * r2 - q1 * r3 + q2 * r0 + q3 * r1;
			out[6] = q0 * r3 + q1 * r2 - q2 * r1 + q3 * r0;
		}
		if (jt == null) {
			return;
		}
		// Derivatives of outputs by normalized quaternion components q0..q3
		double [][] dfdq = new double[7][4];
		// d(x')/dq
		dfdq[0][0] =  2.0 * ( q0 * x - q3 * y + q2 * z);
		dfdq[0][1] =  2.0 * ( q1 * x + q2 * y + q3 * z);
		dfdq[0][2] =  2.0 * (-q2 * x + q1 * y + q0 * z);
		dfdq[0][3] =  2.0 * (-q3 * x - q0 * y + q1 * z);
		// d(y')/dq
		dfdq[1][0] =  2.0 * ( q3 * x + q0 * y - q1 * z);
		dfdq[1][1] =  2.0 * ( q2 * x - q1 * y - q0 * z);
		dfdq[1][2] =  2.0 * ( q1 * x + q2 * y + q3 * z);
		dfdq[1][3] =  2.0 * ( q0 * x - q3 * y + q2 * z);
		// d(z')/dq
		dfdq[2][0] =  2.0 * (-q2 * x + q1 * y + q0 * z);
		dfdq[2][1] =  2.0 * ( q3 * x + q0 * y - q1 * z);
		dfdq[2][2] =  2.0 * (-q0 * x + q3 * y - q2 * z);
		dfdq[2][3] =  2.0 * ( q1 * x + q2 * y + q3 * z);
		// d(q*q_or)/dq
		dfdq[3][0] = r0;  dfdq[3][1] = -r1; dfdq[3][2] = -r2; dfdq[3][3] = -r3;
		dfdq[4][0] = r1;  dfdq[4][1] =  r0; dfdq[4][2] =  r3; dfdq[4][3] = -r2;
		dfdq[5][0] = r2;  dfdq[5][1] = -r3; dfdq[5][2] =  r0; dfdq[5][3] =  r1;
		dfdq[6][0] = r3;  dfdq[6][1] =  r2; dfdq[6][2] = -r1; dfdq[6][3] =  r0;

		// Chain rule: d/dQ_i = sum_j (d/dq_j) * (dq_j/dQ_i), with q = Q/|Q|
		double s2 = Q[0] * Q[0] + Q[1] * Q[1] + Q[2] * Q[2] + Q[3] * Q[3];
		if (s2 <= 0.0) {
			for (int i = 0; i < 4; i++) {
				for (int k = 0; k < 7; k++) {
					jt[i][k] = 0.0;
				}
			}
			return;
		}
		double s = Math.sqrt(s2);
		double invs = 1.0 / s;
		double [] qn = new double[] {Q[0] * invs, Q[1] * invs, Q[2] * invs, Q[3] * invs};
		for (int i = 0; i < 4; i++) {
			for (int outi = 0; outi < 7; outi++) {
				double v = 0.0;
				for (int j = 0; j < 4; j++) {
					// dq_j/dQ_i = (delta_ij - qn_j * qn_i) / |Q|
					double dqjdQi = ((j == i) ? invs : 0.0) - (qn[j] * qn[i]) * invs;
					v += dfdq[outi][j] * dqjdQi;
				}
				jt[i][outi] = v;
			}
		}
	}

}
