package com.elphel.imagej.ims;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;

import com.elphel.imagej.cameras.CLTParameters;
import com.elphel.imagej.tileprocessor.ImageDtt;

import Jama.Matrix;

/**
 * Quaternion correction using a single rotation about a fixed axis.
 * The axis is defined as the cross product of the dominant motion directions
 * for IMS (GNSS/IMU) and camera tracks (best-fit line directions).
 * The sole parameter is the rotation angle around this axis.
 */
public class QuatCorrAxisLMA {
	/** Number of coordinates (x, y, z) contributed by each sample. */
	private static final int    DIM = 3;
	/** Vectors shorter than this are treated as zero (degenerate direction or axis). */
	private static final double MIN_NORM = 1e-12;
	/** Number of power-iteration steps run from each starting vector in principalDirection(). */
	private static final int    POWER_ITERATIONS = 20;
	/**
	 * Added to the caller's debug level in getQuatLMA().
	 * NOTE(review): looks like a leftover debugging override that makes the summary printout
	 * practically unconditional; kept to preserve the current console output - confirm and remove.
	 */
	private static final int    DEBUG_LEVEL_BOOST = 3;

	private int               samples =         0;    // number of input samples (including invalid ones)
	private double []         axis =            null; // unit axis
	private double []         last_rms =        null; // {rms, rms_pure}
	private double []         good_or_bad_rms = null; // {rms, rms_pure} of the last tried step, accepted or not
	private double []         initial_rms =     null; // {rms, rms_pure} at the initial angle estimate
	private double []         parameters_vector = null; // [angle]
	private double []         x_vector =        null; // IMS coordinates, DIM values per sample
	private double []         y_vector =        null; // camera coordinates, DIM values per sample
	private double []         weights;          // normalized so sum is 1.0
	private double            pure_weight;      // weight of samples only
	private double []         last_ymfx =       null; // weighted (y - f(x)) for the current parameters
	private double [][]       last_jt =         null; // transposed Jacobian for the current parameters

	/**
	 * @return {rms, rms_pure} of the last accepted LMA step, or null if the LMA was not run
	 */
	public double [] getLastRms() {
		return last_rms;
	}

	/**
	 * @return {rms, rms_pure} calculated at the initial angle estimate, or null if the LMA was not run
	 */
	public double [] getInitialRms() {
		return initial_rms;
	}

	/**
	 * @return current parameters vector, a single element array [angle]
	 */
	public double [] getVector() { // [angle]
		return parameters_vector;
	}

	/**
	 * @return unit rotation axis selected by prepareLMA()
	 */
	public double [] getAxis() {
		return axis;
	}

	/**
	 * @return quaternion [w,x,y,z] of the rotation by the current angle around the axis
	 */
	public double [] getNormQuat() {
		return angleToQuat(parameters_vector[0]);
	}

	/**
	 * Calculate quaternion (single-axis) to rotate array of 3d vectors ims_xyz[][3] to
	 * array cam_xyz[][3] with LMA using an axis defined by dominant motion directions.
	 * @param clt_parameters parameters, including for the LMA
	 * @param ims_xyz array of the IMS coordinates [...][3], relative to a reference scene
	 * @param cam_xyz array of the camera coordinates [...][3], relative to a same reference scene
	 * @param weights weights per vector or null (will use the same weights for all vectors)
	 * @param rmse null or double [1] to output LMA RMSE. If double[2], rmse[1] will receive
	 *        the initial RMSE (calculated for the initial angle estimate, before any LMA step)
	 * @param debugLevel debug level, DEBUG_LEVEL_BOOST is added to it inside this method
	 * @return normalized quaternion as double[4] or null if failed
	 */
	public static double [] getQuatLMA(
			CLTParameters   clt_parameters,
			double [][]     ims_xyz,
			double [][]     cam_xyz,
			double []       weights,
			double []       rmse, // double[1] or null
			int             debugLevel) {
		debugLevel += DEBUG_LEVEL_BOOST;
		QuatCorrAxisLMA quatCorrAxisLMA = new QuatCorrAxisLMA();
		quatCorrAxisLMA.prepareLMA(
				ims_xyz,
				cam_xyz,
				weights);
		int OK = quatCorrAxisLMA.runLma(
				clt_parameters.imp.imsq_lambda,
				clt_parameters.imp.imsq_lambda_scale_good,
				clt_parameters.imp.imsq_lambda_scale_bad,
				clt_parameters.imp.imsq_lambda_max,
				clt_parameters.imp.imsq_rms_diff,
				clt_parameters.imp.imsq_num_iter,
				debugLevel);
		if (OK < 0) {
			return null;
		}
		if ((rmse != null) && (rmse.length > 0)) {
			rmse[0] = quatCorrAxisLMA.getLastRms()[0];
			if (rmse.length > 1) {
				rmse[1] = quatCorrAxisLMA.getInitialRms()[0];
			}
		}
		double [] qn = quatCorrAxisLMA.getNormQuat(); // [w,x,y,z]
		if (debugLevel > -3) {
			double [] p = quatCorrAxisLMA.getVector();
			double l2 = Math.sqrt(qn[0]*qn[0]+qn[1]*qn[1]+qn[2]*qn[2]+qn[3]*qn[3]);
			System.out.println("QuatCorrAxisLMA: angle= "+p[0]+
					", len="+l2+", RMSE="+quatCorrAxisLMA.getLastRms()[0]+" ("+quatCorrAxisLMA.getInitialRms()[0]+")");
			System.out.println("QuatCorrAxisLMA: quat normalized= ["+qn[0]+", "+qn[1]+", "+qn[2]+", "+qn[3]+"]");
		}
		return qn;
	}

	/**
	 * Copy input data, normalize weights, select the rotation axis and the initial angle.
	 * Samples where either ims_xyz[n] or cam_xyz[n] is null are ignored (get zero weight).
	 * @param ims_xyz array of the IMS coordinates [...][3], elements may be null
	 * @param cam_xyz array of the camera coordinates [...][3], same length as ims_xyz, elements may be null
	 * @param sample_weights weights per sample or null to use the same weight for all samples
	 * @throws IllegalArgumentException if the array lengths do not match
	 */
	public void prepareLMA(
			double [][]     ims_xyz,
			double [][]     cam_xyz,
			double []       sample_weights) {
		samples = ims_xyz.length;
		if (cam_xyz.length != samples) {
			throw new IllegalArgumentException ("QuatCorrAxisLMA.prepareLMA: ims_xyz and cam_xyz should have the same length");
		}
		if (sample_weights == null) {
			sample_weights = new double[samples];
			Arrays.fill(sample_weights, 1.0);
		}
		if (sample_weights.length != samples) {
			throw new IllegalArgumentException ("QuatCorrAxisLMA.prepareLMA: sample_weights should be null or have the same length as data arrays");
		}
		boolean [] valid = new boolean [samples];
		y_vector = new double[DIM * samples];
		x_vector = new double[DIM * samples];
		double sum_w = 0;
		for (int n = 0; n < samples; n++) if ((ims_xyz[n]!=null) &&(cam_xyz[n]!=null)){
			valid[n]=true;
			System.arraycopy(ims_xyz[n], 0, x_vector, DIM*n, DIM);
			System.arraycopy(cam_xyz[n], 0, y_vector, DIM*n, DIM);
			sum_w += sample_weights[n];
		}
		// With no valid samples (or zero total weight) keep all weights zero instead of
		// producing Infinity/NaN; runLma() then fails on the singular matrix and returns -1.
		double ws = (sum_w > 0) ? (1.0 / DIM / sum_w) : 0.0;
		weights = new double [DIM * samples];
		for (int n = 0; n < samples; n++) if (valid[n]) {
			Arrays.fill(weights, DIM*n, DIM*(n+1), ws*sample_weights[n]);
		}
		pure_weight = 1.0; // there are no regularization terms, all the weight belongs to samples
		// Estimate dominant directions and axis
		double[] dirIms = principalDirection(ims_xyz, valid);
		double[] dirCam = principalDirection(cam_xyz, valid);
		// Principal directions are defined up to a sign. Orient the camera direction consistently
		// with the IMS one, otherwise the initial angle may be off by about 180 degrees.
		if (directionsOpposed(ims_xyz, cam_xyz, valid, dirIms, dirCam)) {
			for (int i = 0; i < DIM; i++) {
				dirCam[i] = -dirCam[i];
			}
		}
		double[] axisRaw = cross(dirIms, dirCam);
		double crossMag = norm(axisRaw);
		double axisNorm = crossMag;
		if (axisNorm < MIN_NORM) { // nearly parallel, fall back to dirIms
			axisRaw = dirIms.clone();
			axisNorm = norm(axisRaw);
		}
		for (int i = 0; i < DIM; i++) {
			axisRaw[i] /= axisNorm;
		}
		this.axis = axisRaw;
		// Initial angle to align directions
		double angle0 = Math.atan2(crossMag, dot(dirIms, dirCam));
		parameters_vector = new double [] {angle0};
		last_jt = new double [parameters_vector.length][];
	}

	/** Dot product of two 3d vectors. */
	private static double dot(double[] a, double[] b) {
		return a[0]*b[0] + a[1]*b[1] + a[2]*b[2];
	}

	/** Euclidean length of a 3d vector. */
	private static double norm(double[] a) {
		return Math.sqrt(dot(a,a));
	}

	/** Cross product a x b of two 3d vectors. */
	private static double[] cross(double[] a, double[] b) {
		return new double[] {
				a[1]*b[2] - a[2]*b[1],
				a[2]*b[0] - a[0]*b[2],
				a[0]*b[1] - a[1]*b[0]};
	}

	/** Product of a 3x3 matrix and a 3d vector. */
	private static double[] multiply(double[][] m, double[] v) {
		return new double[] {
				m[0][0]*v[0] + m[0][1]*v[1] + m[0][2]*v[2],
				m[1][0]*v[0] + m[1][1]*v[1] + m[1][2]*v[2],
				m[2][0]*v[0] + m[2][1]*v[1] + m[2][2]*v[2]};
	}

	/**
	 * Unweighted mean of the valid, non-null points.
	 * @return mean as double[3] or null if there are no such points
	 */
	private static double[] centroid(double[][] pts, boolean[] valid) {
		double[] mean = new double[DIM];
		int count = 0;
		for (int i = 0; i < pts.length; i++) if (valid[i] && (pts[i]!=null)) {
			mean[0] += pts[i][0];
			mean[1] += pts[i][1];
			mean[2] += pts[i][2];
			count++;
		}
		if (count == 0) {
			return null;
		}
		for (int k = 0; k < DIM; k++) {
			mean[k] /= count;
		}
		return mean;
	}

	/**
	 * Find out if two principal directions are oriented opposite to each other with respect
	 * to the data. Corresponding samples of both tracks are projected (relative to the track
	 * centroids) on their directions; a negative sum of products of the projections means
	 * that the tracks run in opposite senses along the directions. If that sum is zero or
	 * NaN, the sign of the dot product of the directions is used instead.
	 * @return true if dirCam should be negated to match the orientation of dirIms
	 */
	private static boolean directionsOpposed(
			double[][] ims_xyz,
			double[][] cam_xyz,
			boolean[]  valid,
			double[]   dirIms,
			double[]   dirCam) {
		double[] meanIms = centroid(ims_xyz, valid);
		double[] meanCam = centroid(cam_xyz, valid);
		if ((meanIms != null) && (meanCam != null)) {
			double corr = 0;
			for (int i = 0; i < ims_xyz.length; i++) if (valid[i] && (ims_xyz[i]!=null) && (cam_xyz[i]!=null)) {
				double projIms = 0;
				double projCam = 0;
				for (int k = 0; k < DIM; k++) {
					projIms += (ims_xyz[i][k] - meanIms[k]) * dirIms[k];
					projCam += (cam_xyz[i][k] - meanCam[k]) * dirCam[k];
				}
				corr += projIms * projCam;
			}
			if (corr > 0) {
				return false;
			}
			if (corr < 0) {
				return true;
			}
		}
		return dot(dirIms, dirCam) < 0;
	}

	/**
	 * Estimate the direction of the largest spread of the points (dominant eigenvector of the
	 * scatter matrix) by power iterations. Iterations are started from each of the three
	 * coordinate axes and the result with the largest variance along it is used, so the
	 * estimate does not fail when the dominant direction is orthogonal to one of the axes.
	 * @param pts array of 3d points, elements may be null
	 * @param valid per-point flags, only points with valid[i] set are used
	 * @return unit vector (sign is arbitrary), {1,0,0} if there are no usable points or no spread
	 */
	private static double[] principalDirection(double[][] pts, boolean[] valid) {
		double[] mean = centroid(pts, valid);
		if (mean == null) return new double[] {1,0,0};
		double[][] cov = new double[DIM][DIM];
		for (int i = 0; i < pts.length; i++) if (valid[i] && (pts[i]!=null)) {
			double dx = pts[i][0]-mean[0];
			double dy = pts[i][1]-mean[1];
			double dz = pts[i][2]-mean[2];
			cov[0][0] += dx*dx; cov[0][1] += dx*dy; cov[0][2] += dx*dz;
			cov[1][0] += dy*dx; cov[1][1] += dy*dy; cov[1][2] += dy*dz;
			cov[2][0] += dz*dx; cov[2][1] += dz*dy; cov[2][2] += dz*dz;
		}
		double[] best = new double[] {1,0,0};
		double bestSpread = Double.NEGATIVE_INFINITY;
		for (int start = 0; start < DIM; start++) {
			double[] v = new double[DIM];
			v[start] = 1.0;
			for (int iter = 0; iter < POWER_ITERATIONS; iter++) {
				double[] nv = multiply(cov, v);
				double nrm = norm(nv);
				if (!(nrm >= MIN_NORM)) break; // zero (or NaN) product, keep the last unit vector
				v[0] = nv[0]/nrm; v[1] = nv[1]/nrm; v[2] = nv[2]/nrm;
			}
			// v is a unit vector, so v' * cov * v is the (unnormalized) variance along v
			double spread = dot(v, multiply(cov, v));
			if (spread > bestSpread) {
				bestSpread = spread;
				best = v;
			}
		}
		return best;
	}

	/**
	 * Calculate the model values (IMS vectors rotated by the angle around the axis) and,
	 * optionally, the transposed Jacobian.
	 * @param vector parameters vector [angle]
	 * @param jt null or array initialized with [vector.length][], rows will be replaced with
	 *           derivatives of each output value by the corresponding parameter
	 * @param debug_level debug level (not used here)
	 * @return rotated vectors, 3 values per sample in the same order as x_vector
	 */
	private double [] getFxDerivs(
			double []         vector, // [angle]
			final double [][] jt, // should be null or initialized with [vector.length][]
			final int         debug_level) {
		double angle = vector[0];
		double [] q = angleToQuat(angle);
		double [][] R = QuatCorrLMA.quaternionToRotationMatrix(q);
		double [] fx = new double [weights.length];

		if (jt != null) {
			for (int i = 0; i < vector.length; i++) {
				jt[i] = new double [weights.length];
			}
		}
		for (int n = 0; n < samples; n++) {
			double [] v1 = new double[DIM];
			double [] v =  new double[DIM];
			System.arraycopy(x_vector, n * DIM, v, 0, DIM);
			for (int i = 0; i < DIM; i++) {
				for (int k = 0; k < DIM; k++) {
					v1[i] += R[i][k]*v[k];
				}
			}
			System.arraycopy(v1, 0, fx, DIM * n, DIM);
			if (jt != null) {
				// derivative of a vector rotated around a fixed unit axis by the angle: axis x (rotated vector)
				double [] dv = axisCross(v1);
				System.arraycopy(dv, 0, jt[0], DIM * n, DIM);
			}
		}
		return fx;
	}

	/**
	 * Calculate weighted differences between the camera data and the model values.
	 * NaN differences are replaced with zeros.
	 * @param fx model values as returned by getFxDerivs()
	 * @param rms_fp null or double[2] to receive {rms, rms_pure}. rms_pure differs from rms only if
	 *        fx has more elements than the sample data, which does not happen in this class
	 * @return weighted differences weights[i] * (y_vector[i] - fx[i])
	 */
	private double [] getYminusFxWeighted(
			final double []   fx,
			final double []   rms_fp // null or [2]
			) {
		final double []     wymfw =       new double [fx.length];
		double s_rms=0;
		double rms_pure=Double.NaN;
		for (int i = 0; i < fx.length; i++) {
			double d = y_vector[i] - fx[i];
			double wd = d * weights[i];
			if (Double.isNaN(wd)) {
				wd = 0.0;
				d = 0.0;
			}
			if (i == (samples * DIM)) { // end of the sample data, only reachable with extra (regularization) terms
				rms_pure = Math.sqrt(s_rms/pure_weight);
			}
			wymfw[i] = wd;
			s_rms += d * wd;
		}
		double rms = Math.sqrt(s_rms); // assuming sum_weights == 1.0;
		if (Double.isNaN(rms_pure)) {
			rms_pure=rms;
		}
		if (rms_fp != null) {
			rms_fp[0] = rms;
			rms_fp[1] = rms_pure;
		}
		return wymfw;
	}

	/**
	 * Calculate the weighted J^T*J matrix with the diagonal elements multiplied by (1 + lambda).
	 * @param lambda LMA damping parameter
	 * @param jt transposed Jacobian [number of parameters][number of values]
	 * @return square matrix [number of parameters][number of parameters]
	 */
	private double [][] getWJtJlambda(
			final double      lambda,
			final double [][] jt)
	{
		final int num_pars = jt.length;
		final int num_pars2 = num_pars * num_pars;
		final int nup_points = jt[0].length;
		final double [][] wjtjl = new double [num_pars][num_pars];
		final Thread[] threads = ImageDtt.newThreadArray(ImageDtt.THREADS_MAX);
		final AtomicInteger ai = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					for (int indx = ai.getAndIncrement(); indx < num_pars2; indx = ai.getAndIncrement()) {
						int i = indx / num_pars;
						int j = indx % num_pars;
						if (j >= i) { // calculate the upper triangle only, then mirror it
							double d = 0.0;
							for (int k = 0; k < nup_points; k++) {
								d += weights[k]*jt[i][k]*jt[j][k];
							}
							wjtjl[i][j] = d;
							if (i == j) {
								wjtjl[i][j] += d * lambda;
							} else {
								wjtjl[j][i] = d;
							}
						}
					}
				}
			};
		}
		ImageDtt.startAndJoin(threads);
		return wjtjl;
	}

	/**
	 * Run LMA iterations until the RMS improvement gets below the threshold, lambda exceeds
	 * its maximum, the matrix can not be inverted or the number of iterations is exhausted.
	 * @param lambda initial value of the damping parameter
	 * @param lambda_scale_good multiply lambda by this value after a successful step
	 * @param lambda_scale_bad multiply lambda by this value after a failed step
	 * @param lambda_max stop when lambda exceeds this value after a failed step
	 * @param rms_diff stop when a successful step reduces RMS by less than this fraction
	 * @param num_iter maximal number of iterations
	 * @param debug_level debug level, values above 1 print each step
	 * @return index of the iteration at which the loop stopped (num_iter if it ran to the end)
	 *         when the last step improved RMS or the final RMS is below the initial one, -1 otherwise
	 */
	public int runLma(
			double lambda,
			double lambda_scale_good,
			double lambda_scale_bad,
			double lambda_max,
			double rms_diff,
			int    num_iter,
			int    debug_level) {
		boolean [] rslt = {false,false};
		this.last_rms = null; // forces lmaStep() to (re)calculate the initial state
		int iter = 0;
		for (iter = 0; iter < num_iter; iter++) {
			rslt =  lmaStep(
					lambda,
					rms_diff,
					debug_level);
			if (rslt == null) {
				return -1;
			}
			if (debug_level > 1) {
				System.out.println("Axis LMA step "+iter+": {"+rslt[0]+","+rslt[1]+"} full RMS= "+good_or_bad_rms[0]+
						" ("+initial_rms[0]+"), pure RMS="+good_or_bad_rms[1]+" ("+initial_rms[1]+") + lambda="+lambda+
						", angle="+parameters_vector[0]);
			}
			if (rslt[1]) {
				break;
			}
			if (rslt[0]) { // good
				lambda *= lambda_scale_good;
			} else {
				lambda *= lambda_scale_bad;
				if (lambda > lambda_max) {
					break;
				}
			}
		}
		// the last step may be a failed one while the earlier steps did improve the result
		if (!rslt[0] && (last_rms != null) && (initial_rms != null) && (last_rms[0] < initial_rms[0])) {
			rslt[0] = true;
		}
		return rslt[0]? iter : -1;
	}

	/**
	 * Try a single LMA step: accept the new parameters if RMS improved, otherwise restore
	 * the residuals and the Jacobian for the current parameters.
	 * @param lambda current value of the damping parameter
	 * @param rms_diff relative RMS improvement below which the fitting is considered finished
	 * @param debug_level debug level, passed to getFxDerivs()
	 * @return {improved, finished} or null if the residuals could not be calculated
	 */
	private boolean [] lmaStep(
			double lambda,
			double rms_diff,
			int debug_level) {
		boolean [] rslt = {false,false};
		if (this.last_rms == null) { // first time
			last_rms = new double[2];
			double [] fx = getFxDerivs(
					parameters_vector,
					last_jt,
					debug_level);
			last_ymfx = getYminusFxWeighted(
					fx,
					last_rms);
			this.initial_rms = this.last_rms.clone();
			this.good_or_bad_rms = this.last_rms.clone();
			if (last_ymfx == null) {
				return null;
			}
		}
		Matrix y_minus_fx_weighted = new Matrix(this.last_ymfx, this.last_ymfx.length);

		Matrix wjtjlambda = new Matrix(getWJtJlambda(
				lambda,
				this.last_jt));
		Matrix jtjl_inv = null;
		try {
			jtjl_inv = wjtjlambda.inverse();
		} catch (RuntimeException e) { // matrix can not be inverted: report "finished" without improvement
			rslt[1] = true;
			return rslt;
		}
		Matrix jty = (new Matrix(this.last_jt)).times(y_minus_fx_weighted);

		Matrix mdelta = jtjl_inv.times(jty);

		double []  delta =      mdelta.getColumnPackedCopy();
		double []  new_vector = parameters_vector.clone();
		for (int i = 0; i < parameters_vector.length; i++) {
			new_vector[i] += delta[i];
		}

		// evaluate the candidate; this also replaces last_jt with the Jacobian at new_vector
		double [] fx = getFxDerivs(
				new_vector,
				last_jt,
				debug_level);
		double [] rms = new double[2];
		last_ymfx = getYminusFxWeighted(
				fx,
				rms);

		if (last_ymfx == null) {
			return null;
		}

		this.good_or_bad_rms = rms.clone();
		if (rms[0] < this.last_rms[0]) { // improved
			rslt[0] = true;
			rslt[1] = rms[0] >=(this.last_rms[0] * (1.0 - rms_diff));
			this.last_rms = rms.clone();

			this.parameters_vector = new_vector.clone();
		} else { // worsened
			rslt[0] = false;
			rslt[1] = false;
			// restore state
			fx = getFxDerivs(
					parameters_vector,
					last_jt,
					debug_level);
			last_ymfx = getYminusFxWeighted(
					fx,
					this.last_rms);
			if (last_ymfx == null) {
				return null;
			}
		}
		return rslt;
	}

	/**
	 * Convert rotation angle around the axis to a quaternion.
	 * @param angle rotation angle in radians
	 * @return quaternion [w,x,y,z] = [cos(angle/2), axis * sin(angle/2)]
	 */
	private double[] angleToQuat(double angle) {
		double half = 0.5 * angle;
		double sh = Math.sin(half);
		double ch = Math.cos(half);
		return new double[] {ch, axis[0]*sh, axis[1]*sh, axis[2]*sh};
	}

	/** Cross product axis x v. */
	private double[] axisCross(double[] v) {
		return new double[] {
				axis[1]*v[2] - axis[2]*v[1],
				axis[2]*v[0] - axis[0]*v[2],
				axis[0]*v[1] - axis[1]*v[0]};
	}
}
