/**
 **
 ** OrthoPairLMA - Fit a pair of fixed-scale orthographic images by modifying
 **                the affine transform matrix of the second one using 2D phase
 **                correlation implemented in GPU
 **
 ** Copyright (C) 2024 Elphel, Inc.
 **
 ** -----------------------------------------------------------------------------**
 **
 **  OrthoPairLMA.java is free software: you can redistribute it and/or modify
 **  it under the terms of the GNU General Public License as published by
 **  the Free Software Foundation, either version 3 of the License, or
 **  (at your option) any later version.
 **
 **  This program is distributed in the hope that it will be useful,
 **  but WITHOUT ANY WARRANTY; without even the implied warranty of
 **  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 **  GNU General Public License for more details.
 **
 **  You should have received a copy of the GNU General Public License
 **  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 ** -----------------------------------------------------------------------------**
 **
 */

package com.elphel.imagej.orthomosaic;

import java.awt.Rectangle;
import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;

import com.elphel.imagej.common.ShowDoubleFloatArrays;
import com.elphel.imagej.gpu.GPUTileProcessor;
import com.elphel.imagej.tileprocessor.ImageDtt;
import com.elphel.imagej.tileprocessor.QuadCLT;

import Jama.Matrix;

public class OrthoPairLMA {
	// NOTE(review): not referenced in the visible part of this file - purpose to be confirmed
	private boolean           thread_invariant= true;
	private int               N =               0; // number of tiles in WOI:  woi.width * woi.height
	private double [][]       aff =             null; // source affine transform to be combined with the LMA-adjusted
	private Rectangle         woi =             null; // bounding box of the non-null tiles, in tiles
	private int               width;                  // full width (in tiles) of the input arrays, used to index them
	private double []         last_rms =        null; // {rms, rms_pure}, matching this.parameters_vector
	private double []         good_or_bad_rms = null; // just for diagnostics, to read last (failed) rms
	private double []         initial_rms =     null; // {rms, rms_pure}, first-calculated rms
	private double []         parameters_vector = null; // 6 parameters, initialized to {1,0,0,1,0,0} in prepareLMA()
//	private double []         x_vector =        null; // not used. Save total weight
	private double []         y_vector =        null; // measured X,Y pairs for WOI tiles followed by 3 extra (regularization) slots
	private double [][]       tile_centers =    null; // per-WOI-tile centers (null for tiles without data)
	private double            weight = 0; // total weight of the samples before normalization
	private double            weight_pure = 0; // fraction of the normalized weight left for samples (1.0 - scaled pull weights)
	private double []         weights; // normalized so sum is 1.0 for all - samples and extra regularization
//	private double            pure_weight; // weight of samples only	
	private double []         last_ymfx =       null; // result of the last getYminusFxWeighted() call
	private double [][]       last_jt =         null; // Jacobian (transposed) filled by the last getFxDerivs() call
	private boolean           origin_center =   false; // true - origin in overlap center, false - top left corner (as it was)
	private double []         origin =          null; // in pixels either {0,0} for top-left or center of the woi
	private double            center_radius =   0;    // center zone radius (in tiles) that has limited standard deviation
	public  int               num_good_tiles = 0;     // number of non-null vector_XYS tiles counted by prepareLMA()
	public boolean            last3only = true;       // passed to compareJT() when verifying the Jacobian in lmaStep()
	/**
	 * Create an LMA instance.
	 * @param origin_center true - use the center of the overlap WOI as the coordinate origin,
	 *                      false - use the top left corner ({0,0})
	 */
	public OrthoPairLMA (boolean origin_center) {
		this.origin_center = origin_center;
	}
	/**
	 * Get the center zone radius calculated by the last call to the
	 * six-argument getCenterRadius(...).
	 * @return center zone radius in tiles (0 until calculated, may be Double.POSITIVE_INFINITY)
	 */
	public double getCenterRadius() {
		return center_radius;
	}
	
	/**
	 * Display the current LMA state as a stack of WOI-sized layers, one pair
	 * (horizontal, vertical) per quantity: measured Y, fX recalculated for the
	 * current parameters, their difference, the last stored y-fx array (left as
	 * zeros when it is not available) and the per-sample weights.
	 * @param title title of the displayed image stack
	 */
	public void debugStateImage(String title) {
		final String [] layer_names = {"Yhor", "Yvert", "fXhor", "fXvert", "Dhor", "Dvert","DWhor", "DWvert", "Whor", "Wvert"};
		final double [][] layers = new double[layer_names.length][N];
		// Jacobian is not needed here, so null is passed for jt
		final double [] model = getFxDerivs(parameters_vector, null, 0);
		final boolean have_ymfx = (last_ymfx != null);
		for (int nTile = 0; nTile < N; nTile++) {
			for (int comp = 0; comp < 2; comp++) { // 0 - horizontal, 1 - vertical
				final int indx = 2 * nTile + comp;
				layers[0 + comp][nTile] = y_vector[indx];
				layers[2 + comp][nTile] = model[indx];
				layers[4 + comp][nTile] = y_vector[indx]-model[indx];
				if (have_ymfx) {
					layers[6 + comp][nTile] = last_ymfx[indx];
				}
				layers[8 + comp][nTile] = weights[indx];
			}
		}
		ShowDoubleFloatArrays.showArrays(layers, woi.width, woi.height, true, title, layer_names);
	}
	
	/**
	 * Prepare LMA fitting: find the bounding WOI of the tiles that have data, reset
	 * the parameters vector to the unity transform {1,0,0,1,0,0}, optionally limit
	 * the samples to a central zone, and build the weights/Y vectors.
	 * If there are fewer than min_good_tiles tiles with data the method returns
	 * right after setting this.woi - N, origin, the parameters and the sample
	 * vectors are NOT (re)initialized in that case.
	 *
	 * @param width          width of the input arrays in tiles (used to convert index to tile X,Y)
	 * @param vector_XYS     per-tile {X, Y, strength} or null for tiles without data
	 * @param centers        per-tile centers in pixels
	 * @param weights_extra  optional (may be null) per-tile weights multiplied by the strength
	 * @param first_run      not referenced in this method
	 * @param min_good_tiles minimal number of tiles with data to continue
	 * @param max_std        maximal standard deviation, passed to getCenterRadius(...)
	 * @param min_std_rad    here only tested for being > 0 to enable the center zone calculation
	 *                       and radial weighting (comparing the radius to it is commented out)
	 * @param tile_rad       passed to getCenterRadius(...) as its min_radius
	 * @param min_tiles_rad  passed to getCenterRadius(...) as its min_tiles
	 * @param src_affine     source affine transform, its two rows are cloned to this.aff
	 * @param pull_skew      regularization weight passed to setSamplesWeightsYCenters()
	 * @param pull_tilt      regularization weight passed to setSamplesWeightsYCenters()
	 * @param pull_scale     regularization weight passed to setSamplesWeightsYCenters()
	 * @param debug_level    debug level, prints a marker line when > 1
	 * @return number of tiles with data (non-null vector_XYS elements)
	 */
	public int prepareLMA(
			// will always calculate relative affine, starting with unity
			int           width, // tilesX
			double [][]   vector_XYS,    // optical flow X,Y, confidence obtained from the correlate2DIterate()
			double [][]   centers,       // tile centers (in pixels)
			double []     weights_extra, // optional, may be null 
			boolean       first_run,
			int           min_good_tiles,
			double        max_std,      // maximal standard deviation to limit center area  
			double        min_std_rad,  // minimal radius of the central area (if less - fail)
			double        tile_rad,
			int           min_tiles_rad,
			
			double [][]   src_affine,
			double        pull_skew,        // ~rotation, = 0 fraction of the total weight == 1
			double        pull_tilt,        // > 0
			double        pull_scale,       // = 0
			
			final int     debug_level) {
		aff = new double [][] {src_affine[0].clone(),src_affine[1].clone()};
		tile_centers = centers; // replaced by a WOI-sized array in setSamplesWeightsYCenters()
		this.width = width;
		int height = vector_XYS.length / width;
		// find bounding box of the tiles that have data and count them
		int min_x = width, min_y=height, max_x = -1, max_y=-1;
		num_good_tiles = 0;
		for (int tile = 0; tile < vector_XYS.length; tile++) if ((vector_XYS[tile] !=null)) {
			int tileX = tile % width;
			int tileY = tile / width;
			if (tileX < min_x) min_x = tileX;
			if (tileX > max_x) max_x = tileX;
			if (tileY < min_y) min_y = tileY;
			if (tileY > max_y) max_y = tileY;
			num_good_tiles++;
		}
		woi = new Rectangle (min_x, min_y, max_x - min_x + 1, max_y - min_y + 1);
		if (num_good_tiles < min_good_tiles) {
			return num_good_tiles; // too few tiles, nothing else is initialized
		}
		origin = new double[2]; // {0,0} - top left corner
		if (origin_center) {
			// center of the WOI converted from tiles to pixels
			origin = new double [] {
					(woi.x + 0.5 * woi.width)  * GPUTileProcessor.DTT_SIZE,
					(woi.y + 0.5 * woi.height) * GPUTileProcessor.DTT_SIZE};
		}
		N = woi.width * woi.height;
		parameters_vector = new double [] {1,0,0,1,0,0}; // unity transform, zero shift
		if (min_std_rad > 0) {
//			int         min_tiles = 4;
			// sets this.center_radius (needs woi, origin and N that are set above)
			getCenterRadius(
					max_std,       // final double      max_std,      // maximal standard deviation to limit center area
					tile_rad,      // min_std_rad,   // final double      min_radius,
					min_tiles_rad, // min_tiles,     // final int         min_tiles,
					vector_XYS,    // final double [][] vector_XYS,
					weights_extra, // final double []   weights_extra, // null or additional weights (such as elevation-based)
					centers);      // final double [][] centers)
//			if (getCenterRadius() < min_std_rad) { // only after last pass of LMA!
//				num_good_tiles = 0;
//				return num_good_tiles;
//			}
			// replace the extra weights with the ones tapered by the distance from the origin
			weights_extra = applyRadialWeights( // uses this.center_radius;
					vector_XYS,    // final double [][] vector_XYS,
					weights_extra, // final double []   weights_extra,
					centers);      // final double [][] centers)	
		}
		
		setSamplesWeightsYCenters(
				vector_XYS,
				weights_extra, // null or additional weights (such as elevation-based)
				centers,
				pull_skew,     // double        pull_skew,        // ~rotation, = 0 fraction of the total weight == 1
				pull_tilt,     // double        pull_tilt,        // > 0
				pull_scale);   // double        pull_scale);,       // = 0

		last_jt = new double [parameters_vector.length][]; // rows are expected to be filled by getFxDerivs()
		if (debug_level > 1) {
			System.out.println("prepareLMA() 1");
		}
		
		// zero placeholders; runLma() resets last_rms and lmaStep() recalculates all three
		last_rms = new double [2];
		initial_rms = last_rms.clone();
		good_or_bad_rms = this.last_rms.clone();
		return num_good_tiles;
	}
	
	/**
	 * Build the 2x3 affine matrix from the current parameters vector
	 * {a00, a01, a10, a11, tx, ty}. The translation column is corrected for
	 * this.origin so the matrix applies to coordinates measured from {0,0}.
	 * @return new 2x3 affine matrix {{a00, a01, tx'},{a10, a11, ty'}}
	 */
	public double [][] getAffine(){
		final double [] p = parameters_vector;
		// (A - I) * origin : shift caused by moving the origin
		final double shift_x = (p[0] -1) * origin[0] + p[1]*origin[1];
		final double shift_y = p[2]*origin[0] + (p[3] - 1) * origin[1];
		return new double [][] {
			{p[0], p[1], p[4] - shift_x},
			{p[2], p[3], p[5] - shift_y}};
	}
	/**
	 * Get the RMS pair stored for the current parameters.
	 * @return {rms, rms_pure} (the internal array, not a copy); null between the start
	 *         of runLma() and its first step
	 */
	public double [] getRms() {
		return last_rms;
		
	}
	/**
	 * Get the RMS pair calculated on the first LMA step, before any parameter change.
	 * @return {rms, rms_pure} (the internal array, not a copy)
	 */
	public double [] getInitialRms() {
		return initial_rms;
	}
	/**
	 * Get the total weight of the samples as summed (before normalization)
	 * in setSamplesWeightsYCenters().
	 * @return sum of the sample weights (each tile counted for both X and Y)
	 */
	public double getWeight() { // not used
		return weight;
	}
	/**
	 * Call getWJtJlambda() for the last Jacobian with lambda = 0.
	 * NOTE(review): presumably this is the (weighted) JtJ matrix without damping -
	 * confirm against getWJtJlambda().
	 * @return square matrix produced by getWJtJlambda(0.0, last_jt)
	 */
	public double [][] getLastJtJ(){
		return getWJtJlambda(
				0.0, // lambda, // *10, // temporary
				this.last_jt);
	}
	//getWJtJlambda(
//	lambda, // *10, // temporary
//	this.last_jt)

	/**
	 * Apply radial weights using the stored this.center_radius (in tiles) as the
	 * radius. Delegates to the four-argument applyRadialWeights().
	 * @param vector_XYS    per-tile {X, Y, strength} or null
	 * @param weights_extra null or additional per-tile weights (such as elevation-based)
	 * @param centers       per-tile centers in pixels
	 * @return per-tile weights, or weights_extra itself (may be null) when the radius is infinite
	 */
	private double [] applyRadialWeights( // uses this.center_radius;
			final double [][] vector_XYS,
			final double []   weights_extra,
			final double [][] centers) { // null or additional weights (such as elevation-based)
		return applyRadialWeights(
				center_radius, // final double      radius_tiles,
				vector_XYS,    // final double [][] vector_XYS,
				weights_extra, // final double []   weights_extra,  // null or additional weights (such as elevation-based)
				centers);      // final double [][] centers);
	}
	/**
	 * Calculate per-tile weights tapered by the distance from this.origin.
	 * The tile strength (vector_XYS[][2], optionally multiplied by weights_extra,
	 * NaN replaced with 0) is scaled by 0.5*(1+cos(pi/2 * r/radius)), which is 1.0
	 * at the origin, 0.5 at the radius and reaches 0 at twice the radius. Tiles at
	 * or beyond twice the radius, and tiles without data, get zero weight.
	 * @param radius_tiles  taper radius in tiles; if infinite no tapering is applied
	 * @param vector_XYS    per-tile {X, Y, strength} or null
	 * @param weights_extra null or additional per-tile weights (such as elevation-based)
	 * @param centers       per-tile centers in pixels
	 * @return new array of weights with the same length as vector_XYS, or weights_extra
	 *         itself (may be null) when radius_tiles is infinite
	 */
	private double [] applyRadialWeights(
			final double      radius_tiles,
			final double [][] vector_XYS,
			final double []   weights_extra,  // null or additional weights (such as elevation-based)
			final double [][] centers) {
		if (Double.isInfinite(radius_tiles)) {
			return weights_extra; // may be null - should be OK
		}
		final double radius_pix = radius_tiles * GPUTileProcessor.DTT_SIZE;
		final double cutoff_pix2 = 4 * (radius_pix * radius_pix); // squared distance where the taper reaches zero
		final double [] tapered = new double [vector_XYS.length];
		final AtomicInteger next_tile = new AtomicInteger(0);
		// the same worker is shared by all threads, tiles are handed out through next_tile
		final Runnable worker = new Runnable() {
			@Override
			public void run() {
				for (int nTile = next_tile.getAndIncrement(); nTile < N; nTile = next_tile.getAndIncrement()) {
					final int col = nTile % woi.width + woi.x;
					final int row = nTile / woi.width + woi.y;
					final int full_index = row * width + col; // index in the full-size input arrays
					if ((vector_XYS[full_index] == null) || (centers[full_index] == null)) {
						continue;
					}
					double strength = vector_XYS[full_index][2];
					if (weights_extra != null) {
						strength *= weights_extra[full_index];
					}
					if (Double.isNaN(strength)) {
						strength = 0;
					}
					final double dx = centers[full_index][0] - origin[0];
					final double dy = centers[full_index][1] - origin[1];
					final double dist_pix2 = dx*dx+dy*dy; // distance from the origin in pixels, squared
					if (dist_pix2 < cutoff_pix2) {
						final double dist_pix = Math.sqrt(dist_pix2);
						final double taper = 0.5*(1.0 + Math.cos(0.5*Math.PI*dist_pix/radius_pix));
						tapered[full_index] = strength*taper;
					} else {
						tapered[full_index] = 0;
					}
				}
			}
		};
		final Thread[] threads = ImageDtt.newThreadArray();
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread(worker);
		}
		ImageDtt.startAndJoin(threads);
		return tapered;
	}
	
	
	
	/**
	 * Calculate center radius (in tiles) to use for LMA when initial rotation or
	 * scaling is inaccurate so the peripheral areas do not fit into correlation range.
	 * In that case try to adjust only the central area first, then increase that area
	 * in next iterations.
	 * <p>
	 * Samples are binned by the rounded distance (in tiles) from this.origin, then the
	 * bins are accumulated from the center outwards and the weighted standard deviation
	 * of all the samples inside is compared to max_std. The result is linearly
	 * interpolated between the last two integer radii. The result is also stored in
	 * this.center_radius. Uses this.woi, this.width, this.N and this.origin, so they
	 * should be set before the call.
	 * @param max_std       maximal standard deviation (average for X and Y) inside center area
	 * @param min_radius    minimal central zone radius that has to have >= min_tiles.
	 *                      Should be increased (with min_tiles) for high resolution images.
	 *                      Also there can be a bush right in the center - need to handle it too.
	 *                      Maybe for the final it should be a fraction of the minimal overlap
	 *                      dimension? 
	 * @param min_tiles     minimal tiles in the center zone
	 * @param vector_XYS    2D correlation-measured X,Y, and strength
	 * @param weights_extra null or optional additional weights of the samples
	 * @param centers       per tile (matching vector_XYS and weights_extra) tile centers in pixels
	 * @return maximal radius from the center (in tiles) where standard deviation of the inside samples is
	 *         below max_std. Returns Double.POSITIVE_INFINITY if std is not reached, and 0 if
	 *         (with positive min_tiles and min_radius) there are fewer than min_tiles samples
	 *         within min_radius.
	 */
	public double getCenterRadius(
	        final double      max_std,      // maximal standard deviation to limit center area  
//	        final double      min_std_rad,  // minimal radius of the central area (if less - fail)
	        final double      min_radius,
	        final int         min_tiles,
			final double [][] vector_XYS,
			final double []   weights_extra, // null or additional weights (such as elevation-based)
			final double [][] centers) {
		boolean dbg = false; // set to true (e.g. in debugger) to show the intermediate data
		final int rad_max = Math.max(woi.width,woi.height)/2 +1; 
		final int rad_length = rad_max +1; 
		final Thread[] threads = ImageDtt.newThreadArray();
		final AtomicInteger ai = new AtomicInteger(0);
		final AtomicInteger ati = new AtomicInteger(0);
		// per-thread, per-radius-bin accumulators (merged after all threads finish)
		final double [][] s0_arr =  new double [threads.length][rad_length];
		final double [][] sx_arr =  new double [threads.length][rad_length];
		final double [][] sx2_arr = new double [threads.length][rad_length];
		final double [][] sy_arr =  new double [threads.length][rad_length];
		final double [][] sy2_arr = new double [threads.length][rad_length];
		final int    [][] sn_arr =  new int [threads.length][rad_length];
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				public void run() {
					int thread_num = ati.getAndIncrement();
					for (int iTile = ai.getAndIncrement(); iTile < N; iTile = ai.getAndIncrement()) {
						int tileX = iTile % woi.width + woi.x;
						int tileY = iTile / woi.width + woi.y;
						int aTile = tileY * width + tileX;
						if ((vector_XYS[aTile] != null) && (centers[aTile] != null)) {
							double w = vector_XYS[aTile][2];
							if (weights_extra != null) w *= weights_extra[aTile]; 
							if (Double.isNaN(w)) w = 0;
							double dx = centers[aTile][0] - origin[0];
							double dy = centers[aTile][1] - origin[1];
							double r_t = Math.sqrt(dx*dx+dy*dy)/GPUTileProcessor.DTT_SIZE; // radius in tiles
							int irt = (int) Math.round(r_t);
							if (irt < rad_length) {
								double fx = vector_XYS[aTile][0];
								double fy = vector_XYS[aTile][1];
								sn_arr [thread_num][irt] += 1;
								s0_arr [thread_num][irt] += w;
								sx_arr [thread_num][irt] += w * fx;
								sx2_arr[thread_num][irt] += w * fx*fx;
								sy_arr [thread_num][irt] += w * fy;
								sy2_arr[thread_num][irt] += w * fy * fy;
							}
						}
					}
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
		// merge per-thread accumulators, each radius bin is handled by a single thread
		final double [] s0 =  new double [rad_length];
		final double [] sx =  new double [rad_length];
		final double [] sx2 = new double [rad_length];
		final double [] sy =  new double [rad_length];
		final double [] sy2 = new double [rad_length];
		final int []    sn =  new int [rad_length];
		ai.set(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				public void run() {
					for (int iRad = ai.getAndIncrement(); iRad < rad_length; iRad = ai.getAndIncrement()) {
						for (int i = 0; i < threads.length; i++) {
							s0 [iRad]+=s0_arr [i][iRad];
							sx [iRad]+=sx_arr [i][iRad];
							sx2[iRad]+=sx2_arr[i][iRad];
							sy [iRad]+=sy_arr [i][iRad];
							sy2[iRad]+=sy2_arr[i][iRad];
							sn [iRad]+=sn_arr [i][iRad];
						}
					}
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
		// cumulative (from the center) sums and the running standard deviation
		double sc0=0, scx=0,scx2=0,scy=0,scy2=0,std_prev=0, std=0;
		int    scn=0;
		int    scnc = 0; // number of samples within min_radius
		center_radius = Double.POSITIVE_INFINITY; 
		for (int iRad=0; iRad < rad_length; iRad++) {
			sc0+= s0 [iRad];
			scx+= sx [iRad];
			scx2+=sx2[iRad];
			scy+= sy [iRad];
			scy2+=sy2[iRad];
			scn+= sn [iRad];
			if (iRad <= ((int) Math.round(min_radius))) {
				scnc = scn;
			}
			if ((scn > min_tiles) && (sc0 > 0)) {// for one tile gets sqcrt() of a small negative error
				std = Math.sqrt((scx2*sc0 - scx*scx + scy2*sc0 - scy*scy)/2)/sc0;
			}
			if ((scn >= min_tiles) && (std >= max_std)) {
				if (iRad==0) {
					center_radius = 0;
					break; // probably can not happen
				}
				// linear interpolation between (iRad-1, std_prev) and (iRad, std)
				center_radius = iRad - (std - max_std) / (std - std_prev);
				break;
			}
			std_prev = std;
		}
	
		if (dbg) {
			final double [] weights_center = applyRadialWeights( // uses this.center_radius;
					vector_XYS, // final double [][] vector_XYS,
					weights_extra, // final double []   weights_extra,
					centers); // final double [][] centers);
			
			String [] titles = {"fx","fy","wc","vw","w","cent-x","cent-y","dx","dy","r_t","irt"};
			final double [][] dbg_img = new double [titles.length][woi.width*woi.height];
			for (int i = 0; i < dbg_img.length; i++) {
				Arrays.fill(dbg_img[i], Double.NaN);
			}
			// The tile counter was exhausted by the previous pass - restart it,
			// otherwise most (or all) of the tiles would be skipped.
			ai.set(0);
			for (int ithread = 0; ithread < threads.length; ithread++) {
				threads[ithread] = new Thread() {
					public void run() {
						for (int iTile = ai.getAndIncrement(); iTile < N; iTile = ai.getAndIncrement()) {
							int tileX = iTile % woi.width + woi.x;
							int tileY = iTile / woi.width + woi.y;
							int aTile = tileY * width + tileX;
							// weights_center is null for infinite radius and null weights_extra
							if (weights_center != null) {
								dbg_img[ 2][iTile] = weights_center[aTile];
							}
							if ((vector_XYS[aTile] != null) && (centers[aTile] != null)) {
								double w = vector_XYS[aTile][2];
								if (weights_extra != null) w *= weights_extra[aTile]; 
								if (Double.isNaN(w)) w = 0;
								double dx = centers[aTile][0] - origin[0];
								double dy = centers[aTile][1] - origin[1];
								double r_t = Math.sqrt(dx*dx+dy*dy)/GPUTileProcessor.DTT_SIZE; // radius in tiles
								int irt = (int) Math.round(r_t);
								dbg_img[ 0][iTile] = vector_XYS[aTile][0];
								dbg_img[ 1][iTile] = vector_XYS[aTile][1];
								dbg_img[ 3][iTile] = vector_XYS[aTile][2];
								dbg_img[ 4][iTile] = w;
								dbg_img[ 5][iTile] = centers[aTile][0];
								dbg_img[ 6][iTile] = centers[aTile][1];
								dbg_img[ 7][iTile] = dx;
								dbg_img[ 8][iTile] = dy;
								dbg_img[ 9][iTile] = r_t;
								dbg_img[10][iTile] = irt;
							}
						}
					}
				};
			}		      
			ImageDtt.startAndJoin(threads);
			ShowDoubleFloatArrays.showArrays(
					dbg_img,
					woi.width,
					woi.height,
					true,
					"getCenterRadius",
					titles);
		}
		
		// not enough samples near the center - report zero radius
		if ((min_tiles > 0) && (min_radius > 0) && (scnc < min_tiles)) {
			center_radius = 0;
		}
		return center_radius;
	}
	
	
	
	/**
	 * Build this.weights, this.y_vector and this.tile_centers for the tiles of this.woi.
	 * Each WOI tile gets two samples (X and Y) that share the same weight - the tile
	 * strength (vector_XYS[][2]) optionally multiplied by weights_extra, NaN replaced
	 * with 0. Three extra elements at the end hold the pull (regularization) weights.
	 * The sample weights are normalized so that together with the pull weights the
	 * sum is 1.0. If the pulls sum exceeds 1.0, they are scaled to the total of 0.5.
	 * Also sets this.weight (sum of the sample weights before normalization) and
	 * this.weight_pure (fraction of the total left for the samples).
	 * @param vector_XYS    per-tile {X, Y, strength} or null
	 * @param weights_extra null or additional per-tile weights (such as elevation-based)
	 * @param centers       per-tile centers in pixels
	 * @param pull_skew     weight of the first extra element
	 * @param pull_tilt     weight of the second extra element
	 * @param pull_scale    weight of the third extra element
	 */
	private void setSamplesWeightsYCenters(
			final double [][] vector_XYS,
			final double []   weights_extra, // null or additional weights (such as elevation-based)
			final double [][] centers,
			final double      pull_skew,        // ~rotation, = 0 fraction of the total weight == 1
			final double      pull_tilt,        // > 0
			final double      pull_scale){        // = 0
		final int num_samples = 2 * N; // X and Y for each WOI tile
		this.weights = new double [num_samples + 3];
		y_vector =     new double [num_samples + 3];
		tile_centers = new double [N][];
		final Thread[] threads = ImageDtt.newThreadArray();
		final AtomicInteger next_tile =   new AtomicInteger(0);
		final AtomicInteger next_thread = new AtomicInteger(0);
		final double [] partial_sums = new double [threads.length]; // per-thread sums of weights
		// Pass 1: copy measurements, centers and not yet normalized weights
		final Runnable collector = new Runnable() {
			@Override
			public void run() {
				final int my_index = next_thread.getAndIncrement();
				for (int nTile = next_tile.getAndIncrement(); nTile < N; nTile = next_tile.getAndIncrement()) {
					final int col = nTile % woi.width + woi.x;
					final int row = nTile / woi.width + woi.y;
					final int full_index = row * width + col; // index in the full-size input arrays
					if ((vector_XYS[full_index] == null) || (centers[full_index] == null)) {
						continue;
					}
					double strength = vector_XYS[full_index][2];
					if (weights_extra != null) {
						strength *= weights_extra[full_index];
					}
					if (Double.isNaN(strength)) {
						strength = 0;
					}
					partial_sums[my_index] += 2*strength;
					final int indx = 2 * nTile;
					weights [indx + 0] = strength;
					weights [indx + 1] = strength;
					y_vector[indx + 0] = vector_XYS[full_index][0];
					y_vector[indx + 1] = vector_XYS[full_index][1];
					tile_centers[nTile] = centers[full_index];
				}
			}
		};
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread(collector);
		}
		ImageDtt.startAndJoin(threads);
		weight = 0.0;
		for (int i = 0; i < partial_sums.length; i++) {
			weight += partial_sums[i];
		}
		if (weight <= 1E-8) {
			System.out.println("!!!!!! setSamplesWeights(): sum_weights="+weight+" <= 1E-8");
		}
		// share of the total weight requested by the pulls, limited to 0.5 if over 1.0
		final double sum_pure = pull_skew + pull_tilt + pull_scale;
		final double scale_pure;
		if (sum_pure > 1.0) {
			System.out.println("sum_pure="+sum_pure+" > 1.0, reducing each component to make it 0.5");
			scale_pure = 0.5/sum_pure;
		} else {
			scale_pure = 1.0;
		}
		weight_pure = 1.0 - scale_pure * sum_pure;
		final double s = weight_pure/weight; // Was 0.5 - already taken care of
		// Pass 2: normalize the sample weights
		next_tile.set(0);
		final Runnable normalizer = new Runnable() {
			@Override
			public void run() {
				for (int nTile = next_tile.getAndIncrement(); nTile < N; nTile = next_tile.getAndIncrement()) {
					final int indx = 2 * nTile;
					weights[indx    ] *= s;
					weights[indx + 1] *= s;
				}
			}
		};
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread(normalizer);
		}
		ImageDtt.startAndJoin(threads);
		weights[num_samples + 0] = pull_skew *  scale_pure;
		weights[num_samples + 1] = pull_tilt *  scale_pure;
		weights[num_samples + 2] = pull_scale * scale_pure;
	}

	/**
	 * Run LMA iterations until convergence, failure, or the iteration limit.
	 * After each improving step lambda is multiplied by lambda_scale_good, after
	 * each worsening one - by lambda_scale_bad; iterations stop when lambda exceeds
	 * lambda_max, when lmaStep() reports completion, or after num_iter steps.
	 * A run that ended with a worsening step is still considered successful if the
	 * last stored RMS is below the initial one.
	 * If num_iter is not positive no step is attempted and -1 is returned.
	 * @param lambda            initial damping parameter
	 * @param lambda_scale_good lambda multiplier after a step that improved RMS
	 * @param lambda_scale_bad  lambda multiplier after a step that did not improve RMS
	 * @param lambda_max        stop when lambda grows above this value
	 * @param rms_diff          relative RMS improvement threshold, passed to lmaStep()
	 * @param num_iter          maximal number of steps
	 * @param last_run          only influences debug printing
	 * @param dbg_prefix        if not null - show debug images (initial and after each
	 *                          step) with titles starting with this prefix
	 * @param debug_level       debug level
	 * @return value of the iteration counter when the loop ended if the result is
	 *         considered improved, -1 on failure
	 */
	public int runLma( // <0 - failed, >=0 iteration number (1 - immediately)
			double lambda,           // 0.1
			double lambda_scale_good,// 0.5
			double lambda_scale_bad, // 8.0
			double lambda_max,       // 100
			double rms_diff,         // 0.001
			int    num_iter,         // 20
			boolean last_run,
			String dbg_prefix,
			int    debug_level)
	{
		boolean [] rslt = {false,false};
		this.last_rms = null; // makes the first lmaStep() recalculate fx, Jacobian and initial RMS
		int iter = 0;
		if (dbg_prefix != null) {
			 debugStateImage(dbg_prefix+"-initial");
		}
		for (iter = 0; iter < num_iter; iter++) {
			rslt =  lmaStep(
					lambda,
					rms_diff,
					debug_level);
			if (dbg_prefix != null) {
				 debugStateImage(dbg_prefix+"-step_"+iter);
			}
			
			if (rslt == null) {
				return -1; // false; // need to check
			}
			if (debug_level > 1) {
				System.out.println("LMA step "+iter+": {"+rslt[0]+","+rslt[1]+"} full RMS= "+good_or_bad_rms[0]+
						" ("+initial_rms[0]+"), pure RMS="+good_or_bad_rms[1]+" ("+initial_rms[1]+") + lambda="+lambda);
			}
			if (rslt[1]) {
				break;
			}
			if (rslt[0]) { // good
				lambda *= lambda_scale_good;
			} else {
				lambda *= lambda_scale_bad;
				if (lambda > lambda_max) {
					break; // not used in lwir
				}
			}
		}
		if (this.last_rms == null) {
			// No step was attempted (num_iter <= 0), so there is no RMS to compare or
			// report - dereferencing last_rms below would throw NullPointerException.
			if (debug_level > 1) {
				System.out.println("runLma(): no LMA steps were run, num_iter="+num_iter);
			}
			return -1;
		}
		if (rslt[0]) { // better
			if (iter >= num_iter) { // better, but num tries exceeded
				if (debug_level > 1) System.out.println("Step "+iter+": Improved, but number of steps exceeded maximal");
			} else {
				if (debug_level > 1) System.out.println("Step "+iter+": LMA: Success");
			}

		} else { // improved over initial ?
			if (last_rms[0] < initial_rms[0]) { // NaN
				rslt[0] = true;
				if (debug_level > 1) System.out.println("Step "+iter+": Failed to converge, but result improved over initial");
			} else {
				if (debug_level > 1) System.out.println("Step "+iter+": Failed to converge");
			}
		}
		boolean show_intermediate = true;
		if (show_intermediate && (debug_level > 0)) {
			System.out.println("LMA: full RMS="+last_rms[0]+" ("+initial_rms[0]+"), pure RMS="+last_rms[1]+" ("+initial_rms[1]+") + lambda="+lambda);
		}
		if (debug_level > 2){ 
			System.out.println("iteration="+iter);
		}
		if (debug_level > 0) {
			if ((debug_level > 1) ||  last_run) { // (iter == 1) || last_run) {
				if (!show_intermediate) { // printed above when show_intermediate is set
					System.out.println("LMA: iter="+iter+",   full RMS="+last_rms[0]+" ("+initial_rms[0]+"), pure RMS="+last_rms[1]+" ("+initial_rms[1]+") + lambda="+lambda);
				}
			}
		}
		if ((debug_level > -2) && !rslt[0]) { // failed
			if ((debug_level > 1) || (iter == 1) || last_run) {
				System.out.println("LMA failed on iteration = "+iter);
			}
			System.out.println();
		}

		return rslt[0]? iter : -1;
	}
	
	
	/**
	 * Perform a single LMA step: solve for the parameters increment using the stored
	 * Jacobian (last_jt) and weighted differences (last_ymfx), try the new parameters
	 * and keep them only if the full RMS decreased. If the RMS did not decrease,
	 * the Jacobian, last_ymfx and last_rms are recalculated for the unchanged parameters.
	 * On the first call after last_rms was reset to null it also calculates the
	 * initial state (fx, Jacobian, initial RMS).
	 * @param lambda      damping parameter passed to getWJtJlambda()
	 * @param rms_diff    the step is reported as final when the new RMS is not less
	 *                    than (1.0 - rms_diff) times the previous one
	 * @param debug_level debug level. NOTE: any value above 0 runs compareJT() on each step
	 * @return {improved, finished} pair or null if getYminusFxWeighted() returned null
	 *         (need to re-init/restart LMA). "finished" is also set (with improved=false)
	 *         when the matrix inversion failed.
	 */
	private boolean [] lmaStep(
			double lambda,
			double rms_diff,
			int debug_level) {
		boolean [] rslt = {false,false};
		// maybe the following if() branch is not needed - already done in prepareLMA !
		if (this.last_rms == null) { //first time, need to calculate all (vector is valid)
			last_rms = new double[2];
			if (debug_level > 1) {
				System.out.println("lmaStep(): first step");
			}
			double [] fx = getFxDerivs(
					parameters_vector, // double []         vector,
					last_jt,           // final double [][] jt, // should be null or initialized with [vector.length][]
					debug_level);      // final int         debug_level)
			last_ymfx = getYminusFxWeighted(
					fx, // final double []   fx,
					last_rms); // final double []   rms_fp // null or [2]
			this.initial_rms = this.last_rms.clone();
			this.good_or_bad_rms = this.last_rms.clone();

			if (debug_level > -1) { // temporary
				/*
				dbgYminusFxWeight(
						this.last_ymfx,
						this.weights,
						"Initial_y-fX_after_moving_objects");
                */
			}
			if (last_ymfx == null) {
				return null; // need to re-init/restart LMA
			}
			// TODO: Restore/implement
			if (debug_level > 3) {
				// compare the analytical Jacobian with the one calculated using finite differences
				double    delta = 1E-5;
			 	double delta_err=compareJT(
			 			parameters_vector, // double [] vector,
						delta,             // double    delta,
						last3only);        // boolean   last3only); // do not process samples - they are tested before
				System.out.println("\nMaximal error = "+delta_err);
				
				/*
				 dbgJacobians(
							corr_vector, // GeometryCorrection.CorrVector corr_vector,
							1E-5, // double delta,
							true); //boolean graphic)
				*/
			}
		}
		// NOTE(review): this check runs on every step for any debug_level > 0 (and in
		// addition to the one above on the first step when debug_level > 3) - confirm it is intended
		if (debug_level > 0) {
			double    delta = 1E-5;
		 	double delta_err=compareJT(
		 			parameters_vector, // double [] vector,
					delta,             // double    delta,
					last3only);        // boolean   last3only); // do not process samples - they are tested before
			System.out.println("\nMaximal error = "+delta_err);
			
			/*
			 dbgJacobians(
						corr_vector, // GeometryCorrection.CorrVector corr_vector,
						1E-5, // double delta,
						true); //boolean graphic)
			*/
		}
		
		// column vector of the stored (y - fx) values
		Matrix y_minus_fx_weighted = new Matrix(this.last_ymfx, this.last_ymfx.length);

		Matrix wjtjlambda = new Matrix(getWJtJlambda(
				lambda, // *10, // temporary
				this.last_jt)); // double [][] jt)
		
		if (debug_level>2) {
			System.out.println("JtJ + lambda*diag(JtJ");
			wjtjlambda.print(18, 6);
		}
		Matrix jtjl_inv = null;
		try {
			jtjl_inv = wjtjlambda.inverse(); // check for errors
		} catch (RuntimeException e) {
			// inversion failed: report {not improved, finished} so the caller stops iterating
			rslt[1] = true;
			if (debug_level > 0) {
				System.out.println("Singular Matrix!");
			}

			return rslt;
		}
		if (debug_level>2) {
			System.out.println("(JtJ + lambda*diag(JtJ).inv()");
			jtjl_inv.print(18, 6);
		}
//last_jt has NaNs
		Matrix jty = (new Matrix(this.last_jt)).times(y_minus_fx_weighted);
		if (debug_level>2) {
			System.out.println("Jt * (y-fx)");
			jty.print(18, 6);
		}
		
		
		// parameters increment
		Matrix mdelta = jtjl_inv.times(jty);
		if (debug_level>2) {
			System.out.println("mdelta");
			mdelta.print(18, 6);
		}

		double scale = 1.0;
		double []  delta =      mdelta.getColumnPackedCopy();
		double []  new_vector = parameters_vector.clone();
		for (int i = 0; i < parameters_vector.length; i++) {
			new_vector[i] += scale * delta[i];
		}
		
		// try the new parameters (this overwrites last_jt and last_ymfx)
		double [] fx = getFxDerivs(
				new_vector, // double []         vector,
				last_jt,           // final double [][] jt, // should be null or initialized with [vector.length][]
				debug_level);      // final int         debug_level)
		double [] rms = new double[2];
		last_ymfx = getYminusFxWeighted(
//				vector_XYS, // final double [][] vector_XYS,
				fx, // final double []   fx,
				rms); // final double []   rms_fp // null or [2]
		if (debug_level > 2) {
			/*
			dbgYminusFx(this.last_ymfx, "next y-fX");
			dbgXY(new_vector, "XY-correction");
			*/
		}

		if (last_ymfx == null) {
			return null; // need to re-init/restart LMA
		}

		this.good_or_bad_rms = rms.clone();
		if (rms[0] < this.last_rms[0]) { // improved
			rslt[0] = true;
			// finished when the relative improvement is not more than rms_diff
			rslt[1] = rms[0] >=(this.last_rms[0] * (1.0 - rms_diff));
			this.last_rms = rms.clone();

			this.parameters_vector = new_vector.clone();
			if (debug_level > 2) {
				// print vectors in some format
				/*
				System.out.print("delta: "+corr_delta.toString()+"\n");
				System.out.print("New vector: "+new_vector.toString()+"\n");
				System.out.println();
				*/
			}
		} else { // worsened
			rslt[0] = false;
			rslt[1] = false; // do not know, caller will decide
			// restore state: recalculate Jacobian, y-fx and rms for the unchanged parameters
			fx = getFxDerivs(
					parameters_vector, // double []         vector,
					last_jt,           // final double [][] jt, // should be null or initialized with [vector.length][]
					debug_level);      // final int         debug_level)
			last_ymfx = getYminusFxWeighted(
					fx, // final double []   fx,
					this.last_rms); // final double []   rms_fp // null or [2]
			if (last_ymfx == null) {
				return null; // need to re-init/restart LMA
			}
			if (debug_level > 2) {
				/*
				 dbgJacobians(
							corr_vector, // GeometryCorrection.CorrVector corr_vector,
							1E-5, // double delta,
							true); //boolean graphic)
							*/
			}
		}
		return rslt;
	}
	
	
	/**
	 * Build the damped, weighted normal matrix for one LMA step.
	 * <p>
	 * Element [i][j] is the sum over k of {@code weights[k] * jt[i][k] * jt[j][k]};
	 * each diagonal element additionally gets {@code lambda} times itself added
	 * (i.e. it is multiplied by {@code 1 + lambda}). The matrix is symmetric, so only
	 * the upper triangle (j >= i) is computed and then mirrored.
	 * <p>
	 * The (i,j) cells are distributed between the threads returned by
	 * {@code ImageDtt.newThreadArray()}; each cell is written by exactly one thread.
	 *
	 * @param lambda damping factor applied to the diagonal
	 * @param jt     transposed Jacobian, indexed as [parameter][sample]; the second
	 *               dimension is indexed together with this.weights (assumed to be at
	 *               least as long as jt[0] - TODO confirm against the caller)
	 * @return square matrix [jt.length][jt.length]
	 */
	private double [][] getWJtJlambda(
			final double      lambda,
			final double [][] jt)
	{
		final int num_pars = jt.length;
		final int num_pars2 = num_pars * num_pars;
		final int nup_points = jt[0].length;
		final double [][] wjtjl = new double [num_pars][num_pars];
		final Thread[] threads = ImageDtt.newThreadArray();
		final AtomicInteger ai = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				public void run() {
					for (int indx = ai.getAndIncrement(); indx < num_pars2; indx = ai.getAndIncrement()) {
						final int i = indx / num_pars;
						final int j = indx % num_pars;
						if (j < i) {
							continue; // lower triangle is filled when its mirrored (upper) cell is processed
						}
						final double [] jt_i = jt[i];
						final double [] jt_j = jt[j];
						double d = 0.0;
						for (int k = 0; k < nup_points; k++) {
							d += weights[k]*jt_i[k]*jt_j[k];
						}
						if (i == j) {
							wjtjl[i][j] = d + d * lambda; // damped diagonal
						} else {
							wjtjl[i][j] = d;
							wjtjl[j][i] = d;
						}
					}
				}
			};
		}
		ImageDtt.startAndJoin(threads);
		return wjtjl;
	}
	
	
	
	/**
	 * Evaluate the model for a parameters vector and, optionally, its transposed Jacobian.
	 * <p>
	 * For every tile with a non-null center, with (x,y) being the center relative to
	 * {@code origin}, two values are produced:
	 * <pre>
	 *   fx[2*tile]     = (vector[0] - 1) * x +  vector[1]      * y + vector[4]
	 *   fx[2*tile + 1] =  vector[2]      * x + (vector[3] - 1) * y + vector[5]
	 * </pre>
	 * Three more values follow at indices 2*N .. 2*N+2. They are derived from
	 * A = V * aff, where V is the 2x2 matrix {{vector[0], vector[1]}, {vector[2], vector[3]}}:
	 * the dot product of the columns of A, half the difference of their squared lengths and
	 * half the sum of their squared lengths minus 1, each multiplied by 1000. All three are
	 * zero when the columns of A are orthonormal.
	 *
	 * @param vector      parameters; indices 0..5 are read
	 * @param jt          null to skip derivatives, otherwise an array whose rows are
	 *                    (re)allocated here to weights.length and filled with d(fx)/d(vector[row])
	 * @param debug_level not used by this method
	 * @return array of weights.length model values
	 */
	private double [] getFxDerivs(
			final double []         vector,
			final double [][] jt, // should be null or initialized with [vector.length][]
			final int         debug_level)
	{
		final int num_values = weights.length;
		final double [] fx = new double [num_values];
		final boolean need_jt = (jt != null);
		if (need_jt) {
			for (int npar = 0; npar < jt.length; npar++) {
				jt[npar] = new double [num_values];
			}
		}
		final Thread[] threads = ImageDtt.newThreadArray();
		final AtomicInteger ai = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				public void run() {
					int iTile;
					while ((iTile = ai.getAndIncrement()) < N) {
						final double [] center = tile_centers[iTile];
						if (center == null) {
							continue;
						}
						final int ix = 2 * iTile;
						final int iy = ix + 1;
						final double x = center[0] - origin[0];
						final double y = center[1] - origin[1];
						fx[ix] = (vector[0] - 1.0) * x +  vector[1]        * y + vector[4];
						fx[iy] =  vector[2] *        x + (vector[3] - 1.0) * y + vector[5];
						if (need_jt) {
							// Rows were freshly allocated above, so all other entries are already 0.0
							jt[0][ix] = x;
							jt[1][ix] = y;
							jt[4][ix] = 1.0;
							jt[2][iy] = x;
							jt[3][iy] = y;
							jt[5][iy] = 1.0;
						}
					}
				}
			};
		}
		ImageDtt.startAndJoin(threads);

		// Elements of A = V * aff
		final double a00 = vector[0]*aff[0][0] + vector[1]*aff[1][0];
		final double a01 = vector[0]*aff[0][1] + vector[1]*aff[1][1];
		final double a10 = vector[2]*aff[0][0] + vector[3]*aff[1][0];
		final double a11 = vector[2]*aff[0][1] + vector[3]*aff[1][1];
		final double aa0 = a00*a01+a10*a11;       // dot product of the columns
		final double aa1 = 0.5*(a00*a00 + a10*a10); // half squared length of column 0
		final double aa2 = 0.5*(a01*a01 + a11*a11); // half squared length of column 1
		final double reg_scale = 1000;
		final int ireg = 2 * N; // index of the first of the three extra values
		fx[ireg    ] = reg_scale*aa0;
		fx[ireg + 1] = reg_scale*(aa1-aa2);
		fx[ireg + 2] = reg_scale*(aa1+aa2-1.0);
		if (need_jt) {
			// dA/d(vector[p]) for p = 0..3
			final double [][][] dA_dp = {
					{{aff[0][0], aff[0][1]},{      0.0,       0.0}},
					{{aff[1][0], aff[1][1]},{      0.0,       0.0}},
					{{      0.0,       0.0},{aff[0][0], aff[0][1]}},
					{{      0.0,       0.0},{aff[1][0], aff[1][1]}}};
			for (int npar = 0; npar < 4; npar++) {
				final double [][] dA = dA_dp[npar];
				final double daa0 = (dA[0][0]*a01 + a00*dA[0][1]) + (dA[1][0]*a11+a10*dA[1][1]);
				final double daa1 = (dA[0][0]*a00) + (dA[1][0]*a10);
				final double daa2 = (dA[0][1]*a01) + (dA[1][1]*a11);
				jt[npar][ireg    ] = reg_scale*(daa0);
				jt[npar][ireg + 1] = reg_scale*(daa1-daa2);
				jt[npar][ireg + 2] = reg_scale*(daa1+daa2);
			}
		}
		return fx;
	}	
	
	/**
	 * Numerically estimate the transposed Jacobian with central differences.
	 * <p>
	 * Each parameter in turn is shifted by +delta/2 and -delta/2, the model is evaluated
	 * with {@code getFxDerivs()} (derivatives disabled) and the difference of the two
	 * results is divided by delta. Entries whose weight is not positive are left at 0.
	 *
	 * @param vector      parameters vector (not modified, a clone is shifted)
	 * @param delta       full width of the central-difference interval
	 * @param debug_level passed through to getFxDerivs()
	 * @return array [vector.length][weights.length] of numeric derivatives
	 */
	private double [][] getFxDerivsDelta(
			double []         vector,
			final double      delta,
			final int         debug_level) {
		final int num_pars =   vector.length;
		final int num_values = weights.length;
		final double [][] jt_num = new double [num_pars][num_values];
		for (int npar = 0; npar < num_pars; npar++) {
			final double [] shifted = vector.clone();
			shifted[npar] += 0.5*delta; // +delta/2
			final double [] fx_plus = getFxDerivs(shifted, null, debug_level);
			shifted[npar] -= delta;     // -delta/2 relative to the original value
			final double [] fx_minus = getFxDerivs(shifted, null, debug_level);
			final double [] row = jt_num[npar];
			for (int i = 0; i < num_values; i++) {
				if (weights[i] > 0) {
					row[i] = (fx_plus[i]-fx_minus[i])/delta;
				}
			}
		}
		return jt_num;
	}

	
	
	
	/**
	 * Compute the weighted residuals (y - fx) and the RMS values, delegating to the
	 * implementation selected by {@code thread_invariant}.
	 *
	 * @param fx     model values
	 * @param rms_fp null or an array of at least 2 elements that receives {rms, rms_pure}
	 * @return weighted residuals as produced by the selected implementation
	 */
	private double [] getYminusFxWeighted(
			final double []   fx,
			final double []   rms_fp // null or [2]
			) {
		return thread_invariant ?
				getYminusFxWeightedInvariant(fx, rms_fp) :
				getYminusFxWeightedFast     (fx, rms_fp);
	}	
	
	/**
	 * Weighted residuals and RMS, summed in a fixed order.
	 * <p>
	 * The first 2*N entries with positive weight are processed in parallel. Each thread
	 * stores weight*(y-fx)^2 into a per-sample array, which is then summed sequentially
	 * by index, so the order of additions does not depend on how the samples were split
	 * between threads. The remaining entries (index >= 2*N) are processed sequentially
	 * and contribute to {@code rms} but not to {@code rms_pure}.
	 *
	 * @param fx     model values, same indexing as y_vector and weights
	 * @param rms_fp null or an array that receives
	 *               {sqrt(sum over all entries), sqrt(sum over first 2*N entries / weight_pure)}
	 * @return array of fx.length with weights[i] * (y_vector[i] - fx[i]);
	 *         0 where a sample weight is not positive
	 */
	private double [] getYminusFxWeightedInvariant(
			final double []   fx,
			final double []   rms_fp // null or [2]
			) {
		final int           num_samples = 2 * N;
		final double []     wymfw =       new double [fx.length];
		final double []     l2_arr =      new double [num_samples]; // per-sample weighted squared residuals
		final AtomicInteger ai =          new AtomicInteger(0);
		final Thread[]      threads =     ImageDtt.newThreadArray();
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				public void run() {
					int i;
					while ((i = ai.getAndIncrement()) < num_samples) {
						if (!(weights[i] > 0)) {
							continue;
						}
						final double d = y_vector[i] - fx[i];
						final double wd = d * weights[i];
						if (Double.isNaN(wd)) {
							System.out.println("getYminusFxWeighted(): weights["+i+"]="+weights[i]+", wd="+wd+
									", y_vector[i]="+y_vector[i]+", fx[i]="+fx[i]);
						}
						l2_arr[i] = d * wd;
						wymfw[i] = wd;
					}
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
		// Sequential, index-ordered sum
		double s_rms = 0.0;
		for (int i = 0; i < l2_arr.length; i++) {
			s_rms += l2_arr[i];
		}
		final double rms_pure = Math.sqrt(s_rms/weight_pure);
		// Entries past the samples are added regardless of the weight sign
		for (int i = num_samples; i < fx.length; i++) {
			final double d = y_vector[i] - fx[i];
			final double wd = d * weights[i];
			wymfw[i] = wd;
			s_rms += d * wd;
		}
		final double rms = Math.sqrt(s_rms); // assuming sum_weights == 1.0; /pure_weight); they should be re-normalized after adding regularization
		if (rms_fp != null) {
			rms_fp[0] = rms;
			rms_fp[1] = rms_pure;
		}
		return wymfw;
	}
	/**
	 * Weighted residuals and RMS with per-thread partial sums.
	 * <p>
	 * The first 2*N entries with positive weight are processed in parallel. Each thread
	 * accumulates weight*(y-fx)^2 for the samples it happened to take, and the per-thread
	 * totals are added afterwards. Which samples go to which thread is decided at run
	 * time, so the order of floating-point additions may differ between calls. The
	 * remaining entries (index >= 2*N) are processed sequentially and contribute to
	 * {@code rms} but not to {@code rms_pure}.
	 *
	 * @param fx     model values, same indexing as y_vector and weights
	 * @param rms_fp null or an array that receives
	 *               {sqrt(sum over all entries), sqrt(sum over first 2*N entries / weight_pure)}
	 * @return array of fx.length with weights[i] * (y_vector[i] - fx[i]);
	 *         0 where a sample weight is not positive
	 */
	private double [] getYminusFxWeightedFast(
			final double []   fx,
			final double []   rms_fp // null or [2]
			) {
		final int           num_samples = 2 * N;
		final Thread[]      threads =     ImageDtt.newThreadArray();
		final double []     wymfw =       new double [fx.length];
		final double []     l2_arr =      new double [threads.length]; // one partial sum per thread
		final AtomicInteger ai =          new AtomicInteger(0);  // next sample index
		final AtomicInteger ati =         new AtomicInteger(0);  // next thread slot
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				public void run() {
					final int thread_num = ati.getAndIncrement();
					double partial = 0.0;
					int i;
					while ((i = ai.getAndIncrement()) < num_samples) {
						if (!(weights[i] > 0)) {
							continue;
						}
						final double d = y_vector[i] - fx[i];
						final double wd = d * weights[i];
						if (Double.isNaN(wd)) {
							System.out.println("getYminusFxWeighted(): weights["+i+"]="+weights[i]+", wd="+wd+
									", y_vector[i]="+y_vector[i]+", fx[i]="+fx[i]);
						}
						partial += d * wd;
						wymfw[i] = wd;
					}
					l2_arr[thread_num] = partial;
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
		double s_rms = 0.0;
		for (int nthread = 0; nthread < l2_arr.length; nthread++) {
			s_rms += l2_arr[nthread];
		}
		final double rms_pure = Math.sqrt(s_rms/weight_pure);
		// Entries past the samples are added regardless of the weight sign
		for (int i = num_samples; i < fx.length; i++) {
			final double d = y_vector[i] - fx[i];
			final double wd = d * weights[i];
			wymfw[i] = wd;
			s_rms += d * wd;
		}
		final double rms = Math.sqrt(s_rms); // assuming sum_weights == 1.0; /pure_weight); they should be re-normalized after adding regularization
		if (rms_fp != null) {
			rms_fp[0] = rms;
			rms_fp[1] = rms_pure;
		}
		return wymfw;
	}
	
 	/**
 	 * Debug helper: compare the analytic Jacobian from {@code getFxDerivs()} with the
 	 * numeric one from {@code getFxDerivsDelta()} and print both to System.out.
 	 * <p>
 	 * For every processed row with a positive weight one tab-separated line is printed:
 	 * row index, analytic derivatives, numeric derivatives and their differences (one
 	 * column per parameter in each group). The last printed line (not terminated with a
 	 * newline) holds the largest absolute difference found for each parameter.
 	 *
 	 * @param vector    parameters vector at which the Jacobians are evaluated
 	 * @param delta     full interval width for the central differences
 	 * @param last3only if true, only the last three rows of the Jacobian are compared
 	 * @return largest absolute difference between analytic and numeric derivatives over
 	 *         all compared rows and parameters
 	 */
 	private double compareJT(
			double [] vector,
			double    delta,
			boolean   last3only) { // do not process samples - they are tested before
		double []  errors=new double [vector.length];
		double [][] jt =  new double [vector.length][];
		System.out.print("Parameters vector = [");
		for (int i = 0; i < vector.length; i++) {
			System.out.print(vector[i]);
			if (i < (vector.length -1)) System.out.print(", ");
		}
		System.out.println("]");
		// Analytic derivatives; rows of jt are allocated by getFxDerivs()
		getFxDerivs(
				vector,
				jt, // final double [][] jt, // should be null or initialized with [vector.length][]
				1); // debug_level);
		// Numeric (central difference) derivatives
		double [][] jt_delta =  getFxDerivsDelta(
				vector, // double []         vector,
				delta, // final double      delta,
				-1); // final int         debug_level)
		int start_index = last3only? (weights.length-3) : 0;
		for (int n = start_index; n < weights.length; n++) if (weights[n] > 0) {
			System.out.print(String.format("%3d",n));
			for (int i = 0; i < vector.length; i++) {
				System.out.print(String.format("\t%12.9f",jt[i][n]));
			}			
			for (int i = 0; i < vector.length; i++) {
				System.out.print(String.format("\t%12.9f",jt_delta[i][n]));
			}			
			for (int i = 0; i < vector.length; i++) {
				System.out.print(String.format("\t%12.9f",jt[i][n]-jt_delta[i][n]));
			}			
			System.out.println();
			for (int i = 0; i < vector.length; i++) {
				// Use the magnitude: a negative mismatch is as much an error as a positive one
				errors[i] = Math.max(errors[i], Math.abs(jt[i][n]-jt_delta[i][n]));
			}
		}
		// Skip the analytic and numeric column groups so errors line up with the differences
		for (int i = 0; i < vector.length; i++) {
			System.out.print("\t\t");
		}			
		for (int i = 0; i < vector.length; i++) {
			System.out.print(String.format("\t%12.9f",errors[i]));
		}			
		double err=0;
		for (int i = 0; i < vector.length; i++) {
			err = Math.max(errors[i], err);
		}
		return err;
	}
	
	
	
}
