/**
 **
 ** TDCorrTile - Transform Domain 2D correlation tile
 **
 ** Copyright (C) 2023 Elphel, Inc.
 **
 ** -----------------------------------------------------------------------------**
 **
 **  TDCorrTile.java is free software: you can redistribute it and/or modify
 **  it under the terms of the GNU General Public License as published by
 **  the Free Software Foundation, either version 3 of the License, or
 **  (at your option) any later version.
 **
 **  This program is distributed in the hope that it will be useful,
 **  but WITHOUT ANY WARRANTY; without even the implied warranty of
 **  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 **  GNU General Public License for more details.
 **
 **  You should have received a copy of the GNU General Public License
 **  along with this program.  If not, see <http://www.gnu.org/licenses/>.
 ** -----------------------------------------------------------------------------**
 **
 */

package com.elphel.imagej.tileprocessor;

import java.util.Arrays;
import java.util.concurrent.atomic.AtomicInteger;

import com.elphel.imagej.cameras.CLTParameters;
import com.elphel.imagej.gpu.GPUTileProcessor;
import com.elphel.imagej.gpu.GpuQuad;

public class TDCorrTile {
	/** Total (accumulated) tile weight, needed for fat zero in phase correlation. */
	double    weight;
	/** Transform-domain 2D correlation data: weighted average of all accumulated tiles. */
	final double [] data;
	
	/**
	 * Constructor of an empty 2D correlation tile: zero weight and all-zero data
	 * of 4*DTT_SIZE*DTT_SIZE elements.
	 */
	public TDCorrTile () {
		weight = 0.0;
		data = new double [4*GPUTileProcessor.DTT_SIZE*GPUTileProcessor.DTT_SIZE];
	}

	/**
	 * Constructor from weight and double [] array. The array is stored as is
	 * (not copied), so later modifications of this tile are visible to the caller.
	 * @param weight tile weight
	 * @param data Transform-domain 2D correlation data as double[]
	 */
	public TDCorrTile (double weight, double [] data) {
		this.weight = weight;
		this.data = data;
	}
	
	/**
	 * Constructor from weight and float [] array (data is copied to a new double[]).
	 * @param weight tile weight
	 * @param data Transform-domain 2D correlation data as float[]
	 */
	public TDCorrTile (double weight, float [] data) {
		this.weight = weight;
		this.data = new double [data.length];
		for (int i = 0; i < data.length; i++) {
			this.data[i] = data[i];
		}
	}
	
	/**
	 * Get Transform-domain 2D correlation tile weight needed for fat zero in phase
	 * correlation and transform to the pixel domain.
	 * @return tile weight
	 */
	public double getWeight() {
		return weight;
	}

	/**
	 * Get the internal double [] array of the single-tile Transform-domain 2D correlation
	 * data (not a copy).
	 * @return double [] array with 2D representation of 2D correlation 
	 */
	public double [] getDoubleData() {
		return data;
	}
	
	/**
	 * Get a float copy of the single-tile Transform-domain 2D correlation data
	 * for transfer to the GPU memory. 
	 * @return new float [] array with 2D representation of 2D correlation 
	 */
	public float [] getFloatData() {
		float [] fdata = new float [data.length];
		for (int i = 0; i < data.length;i++) {
			fdata[i] = (float) data[i];
		}
		return fdata;
	}
	
	/**
	 * Clone a single-tile Transform-domain 2D correlation data (deep copy of the data array).
	 * @return new tile with the same weight and a copy of the data
	 */
	@Override
	public TDCorrTile clone() {
		return new TDCorrTile(weight, data.clone());
	}

	/**
	 * Accumulate single-tile Transform-domain 2D correlation data, keep track of weights.
	 * Unity weight for source.
	 * @param tile Transform-domain representation of 2D correlation for a single tile.
	 */
	public void accumulate(TDCorrTile tile) {
		accumulate(tile, 1.0);
	}
	
	/**
	 * Accumulate single-tile Transform-domain 2D correlation data, keep track of weights.
	 * After the call data is the weighted average of the old data (weight = this.weight)
	 * and tile.data (weight = src_weight * tile.weight), and this.weight is the sum of
	 * both weights. If the sum of weights is zero the data is left unchanged (the
	 * averaging would otherwise be 0/0 and fill the tile with NaN).
	 * @param tile Transform-domain representation of 2D correlation for a single tile.
	 * @param src_weight multiplier applied to the source tile weight
	 */
	public void accumulate(TDCorrTile tile, double src_weight) {
		final double add_weight = src_weight * tile.weight;
		final double sum_weight = weight + add_weight;
		if (sum_weight != 0.0) { // guard against 0/0 -> NaN when both weights are zero
			final double w1 = weight / sum_weight;
			final double w2 = 1.0 - w1;
			for (int i = 0; i < data.length; i++) {
				data[i] = w1 * data[i] + w2*tile.data[i];
			}
		}
		weight = sum_weight;
	}
	
	/**
	 * Scale the weight of this Transform-domain 2D correlation tile. Only the weight
	 * is changed - data holds a weighted average and stays the same; the new weight
	 * affects later accumulation and fat zero normalization.
	 * @param weight weight multiplier
	 */
	public void scale (double weight) {
		this.weight *= weight;
	}

	/**
	 * Accumulate transform-domain 2D correlation tiles, keep track of the tile weights
	 * needed for fat zero adjustments during phase correlation and transform to pixel
	 * domain. Unity weight for source.
	 * @param dst destination array of tiles (some may be null). This array will be
	 *            modified.
	 * @param src source array of tiles (some may be null) to add to destination,
	 *            same length as dst.
	 */
	public static void accumulate (
			final TDCorrTile [] dst,
			final TDCorrTile [] src) {
		accumulate (dst, src, 1.0);
	}

	/**
	 * Accumulate transform-domain 2D correlation tiles, keep track of the tile weights
	 * needed for fat zero adjustments during phase correlation and transform to pixel
	 * domain. Where a destination tile is null, it is replaced by a clone of the source
	 * tile with its weight multiplied by src_weight (same as accumulating into an
	 * empty tile).
	 * @param dst destination array of tiles (some may be null). This array will be
	 *            modified.
	 * @param src source array of tiles (some may be null) to add to destination,
	 *            same length as dst.
	 * @param src_weight scale source array during accumulation
	 */
	public static void accumulate (
			final TDCorrTile [] dst,
			final TDCorrTile [] src,
			final double src_weight) {
		final Thread[] threads = ImageDtt.newThreadArray(ImageDtt.THREADS_MAX);
		final AtomicInteger ai = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					for (int iTile = ai.getAndIncrement(); iTile < dst.length; iTile = ai.getAndIncrement()) {
						final TDCorrTile src_tile = src[iTile];
						if (src_tile == null) {
							continue; // nothing to add, destination (null or not) is unchanged
						}
						if (dst[iTile] == null) {
							dst[iTile] = src_tile.clone();
							dst[iTile].scale(src_weight); // keep weights consistent with accumulate()
						} else {
							dst[iTile].accumulate(src_tile, src_weight);
						}
					}
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
	}
	
	/**
	 * Clone array of tiles (sparse). Null tiles stay null.
	 * @param tiles array of tiles to be cloned
	 * @return array of tiles clone
	 */
	public static TDCorrTile [] cloneTiles(
			final TDCorrTile [] tiles) {
		final TDCorrTile [] rslt = new TDCorrTile [tiles.length];
		final Thread[] threads = ImageDtt.newThreadArray(ImageDtt.THREADS_MAX);
		final AtomicInteger ai = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					for (int iTile = ai.getAndIncrement(); iTile < tiles.length; iTile = ai.getAndIncrement()) if (tiles[iTile] != null) {
						rslt[iTile] = tiles[iTile].clone();
					}
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
		return rslt;
	}
	
	/**
	 * Get number of defined (non-null) tiles
	 * @param tiles sparse array of tiles
	 * @param indices - null or int [tiles.length] that will be set to have unique
	 *                  indices (0 .. number_of_tiles - 1) for non-null tiles. Indices for
	 *                  null tiles will not be modified. The order of assigned indices is
	 *                  not guaranteed to follow the tile order (assigned concurrently).
	 * @return number of non-null tiles
	 */
	public static int getNumTiles(
			final TDCorrTile [] tiles,
			final int [] indices) {
		final Thread[] threads = ImageDtt.newThreadArray(ImageDtt.THREADS_MAX);
		final AtomicInteger ai = new AtomicInteger(0);
		final AtomicInteger anum_tiles = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					for (int iTile = ai.getAndIncrement(); iTile < tiles.length; iTile = ai.getAndIncrement()) if (tiles[iTile] != null) {
						final int index = anum_tiles.getAndIncrement();
						if (indices != null) {
							indices[iTile] = index;
						}
					}
				}
			};
		}
		ImageDtt.startAndJoin(threads);
		return anum_tiles.get();
	}
	
	/**
	 * Prepare tile-mapped array of transform-domain 2D correlations from the center tile
	 * and eight of its immediate neighbors
	 * @param tiles tile-mapped array of transform-domain 2D correlations and tile weights
	 *              needed for correct application of fat zero for phase correlation and
	 *              transform to pixel domain.
	 * @param tilesX number of tiles in a scan row.
	 * @param neib_weights_od 2-element array of the relative (to the center tile) weights
	 *                        of neighbors in four orthogonal [0] and four diagonal [1]
	 *                        directions (assumes TileNeibs directions alternate orthogonal/
	 *                        diagonal starting with an orthogonal one).
	 * @param corr_pd pixel-domain phase correlation corresponding to tiles to discard too
	 *                strong neighbors. They likely belong the foreground while this (weak)
	 *                is from the background. If null - do not filter by strength.                       
	 * @param neib_too_strong a neighbor is not accumulated if any value of its corr_pd
	 *                tile is greater than or equal to this threshold.
	 * @param process_all process tiles even if the center tile is null. False - process only
	 *                    the tiles around non-null ones.           
	 * @return tile-mapped array of the combined transform-domain 2D correlations.
	 */
	public static TDCorrTile [] calcNeibs(
			TDCorrTile [] tiles,
			final int     tilesX,
			double []     neib_weights_od, // {orhto, diag}
			double [][]   corr_pd,
			double        neib_too_strong,
			final boolean process_all			
			) {
		final int tilesY = tiles.length / tilesX;
		final TDCorrTile [] rslt = new TDCorrTile [tiles.length];
		final double [] weights = {
				neib_weights_od[0], neib_weights_od[1],
				neib_weights_od[0], neib_weights_od[1],
				neib_weights_od[0], neib_weights_od[1],
				neib_weights_od[0], neib_weights_od[1]};

		final Thread[] threads = ImageDtt.newThreadArray(ImageDtt.THREADS_MAX);
		final AtomicInteger ai = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					TileNeibs tn = new TileNeibs(tilesX,tilesY);
					for (int iTile = ai.getAndIncrement(); iTile < tiles.length; iTile = ai.getAndIncrement()) if (process_all || (tiles[iTile] != null)) {
						if (tiles[iTile] != null) {
							rslt[iTile] = tiles[iTile].clone();
						}
						for (int dir = 0; dir < tn.dirs; dir++) {
							int iTile1 = tn.getNeibIndex(iTile, dir);
							if ((iTile1 >=0) && (tiles[iTile1] != null)) {
								if ((corr_pd != null) && (corr_pd[iTile1] != null)) {
									boolean is_weak = true;
									for (double s:corr_pd[iTile1]) if (s >= neib_too_strong) {
										is_weak = false;
										break;
									}
									if (!is_weak) {
										continue; // skip this strong neighbor from accumulating
									}
								}
								if (rslt[iTile] == null) {
									rslt[iTile] = tiles[iTile1].clone();
									rslt[iTile].scale(weights[dir]);
								} else {
									rslt[iTile].accumulate(tiles[iTile1], weights[dir]);
								}
							}
						}
					}
				}
			};
		}
		ImageDtt.startAndJoin(threads);
		return rslt;
	}
	
	/**
	 * Average neighbors in TD primarily for rectilinear image matching 
	 * @param tiles tile-mapped array of transform-domain 2D correlations and tile weights
	 *              needed for correct application of fat zero for phase correlation and
	 *              transform to pixel domain.
	 * @param tilesX number of tiles in a scan row.
	 * @param radius radius (in X and Y direction) for averaging neighbors using cosine
	 *               for weights. Radius corresponds to PI/2, accumulation happens for
	 *               integer values. Radius below 1.0 (including zero or negative)
	 *               uses only the center tile with unity weight.
	 * @param process_all process tiles even if the center tile is null. False - process only
	 *                    the tiles around non-null ones.           
	 * @return same format as input tiles
	 */
	public static TDCorrTile [] calcNeibs(
			TDCorrTile [] tiles,
			final int     tilesX,
			final double  radius,
			final boolean process_all			
			) {
		final int tilesY = tiles.length / tilesX;
		// clamp: negative radius would otherwise produce a negative array size
		final int irad = Math.max(0, (int) Math.floor(radius));
		final int size = 2*irad+1;
		// 1D cosine window; zero offset always has weight 1.0 (avoids 0/0 for radius == 0)
		final double [] wnd1d = new double [size];
		for (int k = -irad; k <= irad; k++) {
			wnd1d[k + irad] = (k == 0) ? 1.0 : Math.cos(k*0.5*Math.PI/radius);
		}
		final double [][] wnd = new double [size][size];
		for (int i = 0; i < size; i++) {
			for (int j = 0; j < size; j++) {
				wnd[i][j] = wnd1d[i] * wnd1d[j];
			}
		}
		final TDCorrTile [] rslt = new TDCorrTile [tiles.length];
		final Thread[] threads = ImageDtt.newThreadArray(ImageDtt.THREADS_MAX);
		final AtomicInteger ai = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					TileNeibs tn = new TileNeibs(tilesX,tilesY);
					for (int iTile = ai.getAndIncrement(); iTile < tiles.length; iTile = ai.getAndIncrement())  if (process_all || (tiles[iTile] != null)) {
						for (int dy = -irad; dy <= irad; dy++) {
							for (int dx = -irad; dx <= irad; dx++) {
								int iTile1 = tn.getNeibIndex(iTile, dx, dy);
								if ((iTile1 >= 0) && (tiles[iTile1] != null)) {
									final double w = wnd[dy+irad][dx+irad];
									if (rslt[iTile] == null) {
										rslt[iTile] = tiles[iTile1].clone();
										rslt[iTile].scale(w);
									} else {
										rslt[iTile].accumulate(tiles[iTile1], w);
									}
								}
							}
						}
					}
				}
			};
		}
		ImageDtt.startAndJoin(threads);
		return rslt;
	}
	
	
	/**
	 * Convert Transform-domain 2D correlation to phase correlation,
	 * inverse-transform it to pixel domain and return result as 
	 * sparse array of double[] tiles mapped in linescan order. Empty
	 * tiles are null. If there are no defined tiles the GPU is not used
	 * and an all-null array is returned.
	 * @param gpuQuad GPU instance used for normalization and transform
	 * @param tiles sparse tile-mapped array of transform-domain correlation tiles
	 * @param corr_type value OR-ed into each correlation index below
	 *                  CORR_NTILE_SHIFT (callers use 0xFE)
	 * @param gpu_fat_zero fat zero for phase correlation (divided per tile by the
	 *                     tile weight on the GPU side)
	 * @param debug_level debug level (currently unused)
	 * @return sparse array in line-scan order. Each element either null or double[225]
	 */
	public static double [][] convertTDtoPD(
			final GpuQuad       gpuQuad,
			final TDCorrTile [] tiles,
			final int           corr_type, // 0xFE
			final double        gpu_fat_zero,
			final int           debug_level
			){
		final int corr_size_td =    4 * GPUTileProcessor.DTT_SIZE * GPUTileProcessor.DTT_SIZE;
		final int gpu_corr_rad =    GPUTileProcessor.DTT_SIZE -1;
		final int [] indices = new int [tiles.length];
		Arrays.fill(indices, -1); // not needed if not used
		final int num_tiles = getNumTiles(
				tiles,
				indices);
		final double [][] mapped_corrs = new double [tiles.length][];
		if (num_tiles == 0) {
			return mapped_corrs; // nothing to process - do not launch GPU with empty data
		}
		final float [] fcorr_data_out = new float [corr_size_td * num_tiles];
		final float [] fcorr_weights = new float [num_tiles];
		final int [] corr_indices = new int [num_tiles];
		final Thread[] threads = ImageDtt.newThreadArray(ImageDtt.THREADS_MAX);
		final AtomicInteger ai = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					for (int iTile = ai.getAndIncrement(); iTile < tiles.length; iTile = ai.getAndIncrement()) if (tiles[iTile] != null) {
						int indx = indices[iTile];
						float [] data = tiles[iTile].getFloatData();
						System.arraycopy(
								data,
								0,
								fcorr_data_out,
								indx * corr_size_td,
								corr_size_td);
						fcorr_weights[indx] = (float) tiles[iTile].getWeight();
						corr_indices[indx] = (iTile << GPUTileProcessor.CORR_NTILE_SHIFT) | corr_type; // 0xfe ?
					}
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
		// set GPU memory
		gpuQuad.setCorrIndicesTdData(
				num_tiles,         // int    num_tiles,  // corr_indices, fdata may be longer than needed
				corr_indices,      // int [] corr_indices,
				fcorr_data_out);   // float [] fdata)
		gpuQuad.execCorr2D_normalize(
				false, // boolean combo, // normalize combo correlations (false - per-pair ones) 
				gpu_fat_zero,            // double fat_zero);
				fcorr_weights,           // float [] fcorr_weights, // null or one per correlation tile (num_corr_tiles) to divide fat zero2
				gpu_corr_rad);           // int corr_radius
		final float [][] fcorr2D = gpuQuad.getCorr2D(gpu_corr_rad); //  int corr_rad);
		ai.set(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					for (int iCorrTile = ai.getAndIncrement(); iCorrTile < corr_indices.length; iCorrTile = ai.getAndIncrement()) {
						// recover the tile number from the packed correlation index
						int iTile = (corr_indices[iCorrTile] >> GPUTileProcessor.CORR_NTILE_SHIFT);
						float [] ftile = fcorr2D[iCorrTile];
						double [] dtile = new double [ftile.length];
						for (int i = 0; i < ftile.length; i++) {
							dtile[i] = ftile[i];
						}
						mapped_corrs[iTile] = dtile;
					}
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
		return mapped_corrs;
	}
	
	/**
	 * Find correlation maxima for each defined pixel-domain 2D correlation tile and keep
	 * only those stronger than a fraction of the overall maximal value.
	 * @param tiles sparse array of pixel-domain 2D correlation tiles, each a square of
	 *              (2*DTT_SIZE-1) x (2*DTT_SIZE-1) values; null - undefined tile.
	 * @param rmax minimal strength relative to the maximal value over all tiles
	 *             (the maximum search starts from 0.0, so it is never negative).
	 * @param centroid_radius 0 - all same weight, > 0 cosine(PI/2*sqrt(dx^2+dy^2)/rad)
	 * @param n_recenter re-center window around new maximum. 0 - no refines (single-pass)
	 * @return sparse array of {mv[0], mv[1], mv[2] - min_strength}, where mv is the
	 *         result of Correlation2d.getMaxXYCm() (presumably X, Y of the maximum and
	 *         its strength - verify), null where the tile is undefined or too weak.
	 */
	public static double [][] getMismatchVector(
			final double [][] tiles,
			double            rmax, 
			final double      centroid_radius, // 0 - all same weight, > 0 cosine(PI/2*sqrt(dx^2+dy^2)/rad)
			final int         n_recenter){ //  re-center window around new maximum. 0 -no refines (single-pass)
		final double [][] vector_field = new double [tiles.length][];
		final int corr_size = 2 * GPUTileProcessor.DTT_SIZE - 1;
		final Thread[] threads = ImageDtt.newThreadArray();
		final AtomicInteger ai = new AtomicInteger(0);
		final AtomicInteger ati = new AtomicInteger(0);
		final double [] th_max = new double [threads.length]; // per-thread maximums, start at 0.0
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					int thread_num = ati.getAndIncrement();
					for (int iTile = ai.getAndIncrement(); iTile < tiles.length; iTile = ai.getAndIncrement()) if (tiles[iTile] != null) {
						for (double d:tiles[iTile]) {
							if (d > th_max[thread_num]) {
								th_max[thread_num] = d;
							}
						}
					}
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
		double amax = th_max[0];
		for (int i = 1; i < th_max.length; i++) {
			if (th_max[i] > amax) {
				amax = th_max[i];
			}
		}
		final double min_str = rmax * amax;
		ai.set(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					for (int iTile = ai.getAndIncrement(); iTile < tiles.length; iTile = ai.getAndIncrement()) if (tiles[iTile] != null) {
						double [] mv = Correlation2d.getMaxXYCm( // last, average
								tiles[iTile],    // corrs.length-1], // double [] data,
								corr_size,       // int       data_width,      //  = 2 * transform_size - 1;
								centroid_radius, // double    radius, // 0 - all same weight, > 0 cosine(PI/2*sqrt(dx^2+dy^2)/rad)
								n_recenter,      // int       refine, //  re-center window around new maximum. 0 -no refines (single-pass)
								null,            // boolean [] fpn_mask,
								false,           // boolean    ignore_border, // only if fpn_mask != null - ignore tile if maximum touches fpn_mask
								false);          // boolean   debug)
						// defensive: treat a missing maximum as an undefined vector
						if ((mv != null) && (mv[2] > min_str)) {
							vector_field[iTile] = new double [] {mv[0], mv[1], mv[2]-min_str};
						}
					}
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
		return vector_field;
	}
	
	
	/**
	 * Get GPU TD data after interscene correlation of 2 scenes (only use
	 * combo of all channels). Tiles whose combo data contains NaN are left null.
	 * @param gpuQuad GPU quad instance
	 * @return TDCorrTile [] array, with weight equal to number of channels
	 *         combined (gpuQuad.getNumCamsInter(), normally 16)
	 */
	public static TDCorrTile [] getFromGpu(
			GpuQuad gpuQuad) {
		final int tilesX =     gpuQuad.getImageWidth()  / GpuQuad.getDttSize();
		final int tilesY =     gpuQuad.getImageHeight() / GpuQuad.getDttSize();
		final TDCorrTile [] tiles = new TDCorrTile[tilesX * tilesY];
		// per tile: one slot per camera pair followed by the combo slot (last)
		final int num_pairs = gpuQuad.getNumCamsInter() + 1;
		final int index_combo = num_pairs - 1;
		final int corr_size_td = 4 * GPUTileProcessor.DTT_SIZE * GPUTileProcessor.DTT_SIZE;
		final int [] indices = gpuQuad.getCorrIndices(); // also sets num_corr_tiles
		final float [] fdata = gpuQuad.getCorrTdData();
		final int num_tiles = gpuQuad.getNumCorrTiles() / num_pairs;

		final Thread[] threads = ImageDtt.newThreadArray(ImageDtt.THREADS_MAX);
		final AtomicInteger ai = new AtomicInteger(0);
		for (int ithread = 0; ithread < threads.length; ithread++) {
			threads[ithread] = new Thread() {
				@Override
				public void run() {
					for (int nt = ai.getAndIncrement(); nt < num_tiles; nt = ai.getAndIncrement()) {
						int nTile = (indices[nt * num_pairs] >> GPUTileProcessor.CORR_NTILE_SHIFT);
						int fdata_offset = (nt * num_pairs + index_combo) * corr_size_td;
						double [] tile_data = new double [corr_size_td];
						copy_data: {
							for (int i = 0; i < tile_data.length; i++) {
								double d = fdata[fdata_offset+i];
								if (Double.isNaN(d)) {
									break copy_data; // accumulated tile may be NaN
								}
								tile_data[i] =d; 
							}
							tiles[nTile] = new TDCorrTile (index_combo, tile_data); //  default weight = 16?
						}
					}
				}
			};
		}		      
		ImageDtt.startAndJoin(threads);
		return tiles;
	}
}
