Commit c3fb4172 authored by Andrey Filippov's avatar Andrey Filippov

Initial testing of LMA for fitting pairs of ortho images

parent 411beb9e
...@@ -52,16 +52,19 @@ public class ComboMatch { ...@@ -52,16 +52,19 @@ public class ComboMatch {
double [][][] image_enuatr = {{{0,0,0},{0,0,0}},{{0,0,0},{0,0,0}}}; double [][][] image_enuatr = {{{0,0,0},{0,0,0}},{{0,0,0},{0,0,0}}};
int gpu_width= clt_parameters.imp.rln_gpu_width; // 3008; int gpu_width= clt_parameters.imp.rln_gpu_width; // 3008;
int gpu_height= clt_parameters.imp.rln_gpu_height; // 3008; int gpu_height= clt_parameters.imp.rln_gpu_height; // 3008;
int zoom_lev = -4; // 0; // +1 - zoom in twice, -1 - zoom out twice int zoom_lev = -3; // 0; // +1 - zoom in twice, -1 - zoom out twice
boolean use_alt = false; boolean use_alt = false;
boolean show_centers = true; boolean show_centers = true;
boolean use_saved_collection = false; boolean use_saved_collection = true; // false;
boolean save_collection = true; boolean save_collection = true;
boolean process_correlation = true; // use false to save new version of data
GenericJTabbedDialog gd = new GenericJTabbedDialog("Set image pair",1200,800); GenericJTabbedDialog gd = new GenericJTabbedDialog("Set image pair",1200,800);
gd.addStringField ("Image list full path", files_list_path, 180, "Image list full path."); gd.addStringField ("Image list full path", files_list_path, 180, "Image list full path.");
gd.addStringField ("Maps collection save path", orthoMapsCollection_path, 180, "Save path for serialized map collection data."); gd.addStringField ("Maps collection save path", orthoMapsCollection_path, 180, "Save path for serialized map collection data.");
gd.addCheckbox ("Use saved maps collection", use_saved_collection, "If false - use files list."); gd.addCheckbox ("Use saved maps collection", use_saved_collection, "If false - use files list.");
gd.addCheckbox ("Save maps collection", save_collection, "If false - use files list."); gd.addCheckbox ("Save maps collection", save_collection, "Save maps collection to be able to restore.");
gd.addCheckbox ("Process correlations", process_correlation, "false to skip to just regenerate new save file.");
//
// for (int n = 0; n < image_paths_pre.length; n++) { // for (int n = 0; n < image_paths_pre.length; n++) {
// gd.addStringField ("Image path "+n, image_paths_pre[n], 180, "Image "+n+" full path w/o ext"); // gd.addStringField ("Image path "+n, image_paths_pre[n], 180, "Image "+n+" full path w/o ext");
// } // }
...@@ -92,6 +95,7 @@ public class ComboMatch { ...@@ -92,6 +95,7 @@ public class ComboMatch {
orthoMapsCollection_path = gd.getNextString(); orthoMapsCollection_path = gd.getNextString();
use_saved_collection = gd.getNextBoolean(); use_saved_collection = gd.getNextBoolean();
save_collection = gd.getNextBoolean(); save_collection = gd.getNextBoolean();
process_correlation= gd.getNextBoolean();
for (int n = 0; n < image_enuatr.length; n++) { for (int n = 0; n < image_enuatr.length; n++) {
image_enuatr[n][0][0] = gd.getNextNumber(); image_enuatr[n][0][0] = gd.getNextNumber();
...@@ -142,7 +146,7 @@ public class ComboMatch { ...@@ -142,7 +146,7 @@ public class ComboMatch {
origin); // int [] origin){ origin); // int [] origin){
imp_alt.show(); imp_alt.show();
} }
if (process_correlation) {
System.out.println("Setting up GPU"); System.out.println("Setting up GPU");
if (GPU_QUAD_AFFINE == null) { if (GPU_QUAD_AFFINE == null) {
try { try {
...@@ -169,6 +173,7 @@ public class ComboMatch { ...@@ -169,6 +173,7 @@ public class ComboMatch {
affines, // double [][][] affines, // on top of GPS offsets affines, // double [][][] affines, // on top of GPS offsets
zoom_lev, // int zoom_lev, zoom_lev, // int zoom_lev,
debugLevel); // final int debugLevel) debugLevel); // final int debugLevel)
}
if (save_collection) { if (save_collection) {
try { try {
maps_collection.writeOrthoMapsCollection(orthoMapsCollection_path); maps_collection.writeOrthoMapsCollection(orthoMapsCollection_path);
...@@ -367,6 +372,7 @@ public class ComboMatch { ...@@ -367,6 +372,7 @@ public class ComboMatch {
gpu_width, // final int img_width, gpu_width, // final int img_width,
null, // Rectangle woi, // if null, use full GPU window null, // Rectangle woi, // if null, use full GPU window
affine, // final double [][][] affine, // [2][2][3] affine coefficients to translate common to 2 images affine, // final double [][][] affine, // [2][2][3] affine coefficients to translate common to 2 images
null, // TpTask [][] tp_tasks_o,
false, // final boolean batch_mode, false, // final boolean batch_mode,
debugLevel); // final int debugLevel); debugLevel); // final int debugLevel);
// renderFromTD ( // renderFromTD (
...@@ -395,6 +401,7 @@ public class ComboMatch { ...@@ -395,6 +401,7 @@ public class ComboMatch {
final int img_width, final int img_width,
Rectangle woi, // if null, use full GPU window Rectangle woi, // if null, use full GPU window
final double [][][] affine, // [2][2][3] affine coefficients to translate common to 2 images final double [][][] affine, // [2][2][3] affine coefficients to translate common to 2 images
TpTask [][] tp_tasks_o,
final boolean batch_mode, final boolean batch_mode,
final int debugLevel) { final int debugLevel) {
int [] wh = {img_width, fpixels[0].length/img_width}; int [] wh = {img_width, fpixels[0].length/img_width};
...@@ -403,6 +410,9 @@ public class ComboMatch { ...@@ -403,6 +410,9 @@ public class ComboMatch {
img_width, // final int img_width, img_width, // final int img_width,
woi, // Rectangle woi, woi, // Rectangle woi,
affine); // final double [][][] affine // [2][2][3] affine coefficients to translate common to 2 images affine); // final double [][][] affine // [2][2][3] affine coefficients to translate common to 2 images
if (tp_tasks_o != null) {
for (int i = 0; i < tp_tasks_o.length; i++) tp_tasks_o[i] = tp_tasks[i];
}
boolean is_aux = true; boolean is_aux = true;
boolean is_mono = true; boolean is_mono = true;
boolean is_lwir = true; boolean is_lwir = true;
......
...@@ -62,6 +62,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{ ...@@ -62,6 +62,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{
public double [][] affine = new double[][] {{1,0,0},{0,1,0}}; // relative to vert_meters[] public double [][] affine = new double[][] {{1,0,0},{0,1,0}}; // relative to vert_meters[]
public double orig_pix_meters; public double orig_pix_meters;
public double [] vert_meters; // offset of the image vertical in meters (scale-invariant) public double [] vert_meters; // offset of the image vertical in meters (scale-invariant)
public int orig_width;
public int orig_height;
public transient FloatImageData orig_image; public transient FloatImageData orig_image;
public transient FloatImageData alt_image; public transient FloatImageData alt_image;
public int orig_zoom_level; public int orig_zoom_level;
...@@ -79,6 +81,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{ ...@@ -79,6 +81,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{
// affine is not transient // affine is not transient
// orig_pix_meters is not transient // orig_pix_meters is not transient
// vert_meters is not transient // vert_meters is not transient
// orig_width is not transient
// orig_height is not transient
// orig_image does not need to be saved // orig_image does not need to be saved
// alt_image does not need to be saved // alt_image does not need to be saved
// orig_zoom_level is not transient // orig_zoom_level is not transient
...@@ -98,6 +102,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{ ...@@ -98,6 +102,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{
// affine is not transient // affine is not transient
// orig_pix_meters is not transient // orig_pix_meters is not transient
// vert_meters is not transient // vert_meters is not transient
// orig_width is not transient
// orig_height is not transient
// orig_image was not saved // orig_image was not saved
// alt_image was not saved // alt_image was not saved
images = new HashMap <Integer, FloatImageData>(); // field images was not saved images = new HashMap <Integer, FloatImageData>(); // field images was not saved
...@@ -117,6 +123,12 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{ ...@@ -117,6 +123,12 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{
gpu_height = height; gpu_height = height;
} }
public int getWidth() {
return orig_width;
}
public int getHeight() {
return orig_height;
}
// Generate ALT image path from the GEO // Generate ALT image path from the GEO
public static String getAltPath(String path) { public static String getAltPath(String path) {
int p1 = path.lastIndexOf("."); int p1 = path.lastIndexOf(".");
...@@ -151,6 +163,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{ ...@@ -151,6 +163,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{
double height_meters = height_pix * orig_pix_meters; double height_meters = height_pix * orig_pix_meters;
vert_meters[1] = height_meters-vert_meters[1]; vert_meters[1] = height_meters-vert_meters[1];
} }
orig_width = ElphelTiffReader.getWidth(imp_prop);
orig_height = ElphelTiffReader.getHeight(imp_prop);
orig_zoom_level = FloatImageData.getZoomLevel(orig_pix_meters); orig_zoom_level = FloatImageData.getZoomLevel(orig_pix_meters);
orig_zoom_valid = FloatImageData.isZoomValid(orig_pix_meters); orig_zoom_valid = FloatImageData.isZoomValid(orig_pix_meters);
...@@ -254,8 +268,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{ ...@@ -254,8 +268,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{
String full_name = path.substring(path.lastIndexOf(Prefs.getFileSeparator()) + 1); String full_name = path.substring(path.lastIndexOf(Prefs.getFileSeparator()) + 1);
ImagePlus imp = ShowDoubleFloatArrays.makeArrays( ImagePlus imp = ShowDoubleFloatArrays.makeArrays(
orig_image.data, // float[] pixels, orig_image.data, // float[] pixels,
orig_image.width, getWidth(),
orig_image.height, getHeight(),
full_name); full_name);
if (show_markers) { if (show_markers) {
PointRoi roi = new PointRoi(); PointRoi roi = new PointRoi();
...@@ -296,8 +310,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{ ...@@ -296,8 +310,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{
*/ */
public double [][] get4SourceCornersMeters(){ public double [][] get4SourceCornersMeters(){
FloatImageData orig_image = getImageData(); FloatImageData orig_image = getImageData();
double width_meters = orig_image.width * orig_pix_meters; double width_meters = getWidth() * orig_pix_meters;
double height_meters = orig_image.height * orig_pix_meters; double height_meters = getHeight() * orig_pix_meters;
return new double[][] { // CW from TL return new double[][] { // CW from TL
{ - vert_meters[0], - vert_meters[1]}, { - vert_meters[0], - vert_meters[1]},
{width_meters - vert_meters[0], - vert_meters[1]}, {width_meters - vert_meters[0], - vert_meters[1]},
...@@ -440,8 +454,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{ ...@@ -440,8 +454,8 @@ public class OrthoMap implements Comparable <OrthoMap>, Serializable{
rscale *= 2; rscale *= 2;
} }
final int frscale = rscale; final int frscale = rscale;
int swidth = orig_image.width; int swidth = getWidth();
int sheight = orig_image.height; int sheight = getHeight();
final float [] spix = orig_image.data; final float [] spix = orig_image.data;
final int width = (swidth+frscale-1)/frscale; final int width = (swidth+frscale-1)/frscale;
final int height = (sheight+frscale-1)/frscale; final int height = (sheight+frscale-1)/frscale;
......
...@@ -20,6 +20,8 @@ import java.util.concurrent.atomic.AtomicInteger; ...@@ -20,6 +20,8 @@ import java.util.concurrent.atomic.AtomicInteger;
import com.elphel.imagej.cameras.CLTParameters; import com.elphel.imagej.cameras.CLTParameters;
import com.elphel.imagej.common.ShowDoubleFloatArrays; import com.elphel.imagej.common.ShowDoubleFloatArrays;
import com.elphel.imagej.gpu.GPUTileProcessor;
import com.elphel.imagej.gpu.TpTask;
import com.elphel.imagej.tileprocessor.ImageDtt; import com.elphel.imagej.tileprocessor.ImageDtt;
import com.elphel.imagej.tileprocessor.TDCorrTile; import com.elphel.imagej.tileprocessor.TDCorrTile;
...@@ -441,6 +443,8 @@ public class OrthoMapsCollection implements Serializable{ ...@@ -441,6 +443,8 @@ public class OrthoMapsCollection implements Serializable{
double [][][] affines, // here in meters, relative to vertical points double [][][] affines, // here in meters, relative to vertical points
int zoom_lev, int zoom_lev,
final int debugLevel){ final int debugLevel){
boolean show_gpu_img = true; // (debugLevel > 1);
boolean show_tile_centers = true; // (debugLevel > 1);
double [][] bounds_overlap_meters = getOverlapMeters( double [][] bounds_overlap_meters = getOverlapMeters(
gpu_pair[0], // int ref_index, gpu_pair[0], // int ref_index,
gpu_pair[1], // int other_index) gpu_pair[1], // int other_index)
...@@ -469,9 +473,9 @@ public class OrthoMapsCollection implements Serializable{ ...@@ -469,9 +473,9 @@ public class OrthoMapsCollection implements Serializable{
tlo_rect_metric[0][0] = bounds_overlap_meters[0][0]; // relative to ref vert_meters tlo_rect_metric[0][0] = bounds_overlap_meters[0][0]; // relative to ref vert_meters
tlo_rect_metric[0][1] = bounds_overlap_meters[1][0]; // vert_meters tlo_rect_metric[0][1] = bounds_overlap_meters[1][0]; // vert_meters
tlo_rect_metric[1][0] = bounds_overlap_meters[0][0] // relative to other vert_meters tlo_rect_metric[1][0] = bounds_overlap_meters[0][0] // relative to other vert_meters
- rd[0] + ortho_maps[gpu_pair[1]].vert_meters[0]- ortho_maps[gpu_pair[0]].vert_meters[0]; - rd[0]; // + ortho_maps[gpu_pair[1]].vert_meters[0]- ortho_maps[gpu_pair[0]].vert_meters[0];
tlo_rect_metric[1][1] = bounds_overlap_meters[1][0] tlo_rect_metric[1][1] = bounds_overlap_meters[1][0]
- rd[1] + ortho_maps[gpu_pair[1]].vert_meters[1]- ortho_maps[gpu_pair[0]].vert_meters[1]; - rd[1]; // + ortho_maps[gpu_pair[1]].vert_meters[1]- ortho_maps[gpu_pair[0]].vert_meters[1];
double [][] tlo_src_metric = new double[tlo_rect_metric.length][2]; // relative to it's own vert_meters double [][] tlo_src_metric = new double[tlo_rect_metric.length][2]; // relative to it's own vert_meters
for (int n=0; n <tlo_src_metric.length; n++) { for (int n=0; n <tlo_src_metric.length; n++) {
...@@ -501,6 +505,17 @@ public class OrthoMapsCollection implements Serializable{ ...@@ -501,6 +505,17 @@ public class OrthoMapsCollection implements Serializable{
} }
gpu_pair_img[n] = ortho_maps[gpu_pair[n]].getPaddedGPU (zoom_lev); // int zoom_level, gpu_pair_img[n] = ortho_maps[gpu_pair[n]].getPaddedGPU (zoom_lev); // int zoom_level,
} }
if (show_gpu_img) {
String [] dbg_titles = {ortho_maps[gpu_pair[0]].getName(),ortho_maps[gpu_pair[1]].getName()};
ShowDoubleFloatArrays.showArrays(
gpu_pair_img,
OrthoMap.gpu_width,
OrthoMap.gpu_height,
true,
"gpu_img",
dbg_titles);
}
Rectangle woi = new Rectangle(0, 0, overlap_wh_pixel[0], overlap_wh_pixel[1]); Rectangle woi = new Rectangle(0, 0, overlap_wh_pixel[0], overlap_wh_pixel[1]);
if (woi.width > OrthoMap.gpu_width) { if (woi.width > OrthoMap.gpu_width) {
if (debugLevel > -3) { if (debugLevel > -3) {
...@@ -517,6 +532,7 @@ public class OrthoMapsCollection implements Serializable{ ...@@ -517,6 +532,7 @@ public class OrthoMapsCollection implements Serializable{
final int gpu_width = OrthoMap.gpu_width; // static final int gpu_width = OrthoMap.gpu_width; // static
// uses fixed_size gpu image size // uses fixed_size gpu image size
// TDCorrTile [] td_corr_tiles = // TDCorrTile [] td_corr_tiles =
TpTask [][] tp_tasks = new TpTask [2][];
double [][][] vector_field = double [][][] vector_field =
ComboMatch.rectilinearVectorField(//rectilinearCorrelate_TD( // scene0/scene1 ComboMatch.rectilinearVectorField(//rectilinearCorrelate_TD( // scene0/scene1
clt_parameters, // final CLTParameters clt_parameters, clt_parameters, // final CLTParameters clt_parameters,
...@@ -524,9 +540,69 @@ public class OrthoMapsCollection implements Serializable{ ...@@ -524,9 +540,69 @@ public class OrthoMapsCollection implements Serializable{
gpu_width, // final int img_width, gpu_width, // final int img_width,
woi, // Rectangle woi, // if null, use full GPU window woi, // Rectangle woi, // if null, use full GPU window
affines_gpu, // final double [][][] affine, // [2][2][3] affine coefficients to translate common to 2 images affines_gpu, // final double [][][] affine, // [2][2][3] affine coefficients to translate common to 2 images
tp_tasks, // TpTask [][] tp_tasks_o,
false, // final boolean batch_mode, false, // final boolean batch_mode,
debugLevel); // final int debugLevel); debugLevel); // final int debugLevel);
// may use tl_rect_metric to remap to the original image // may use tl_rect_metric to remap to the original image
double [][] tile_centers = new double [vector_field[0].length][];
int tilesX = gpu_width/GPUTileProcessor.DTT_SIZE;
for (TpTask task: tp_tasks[1]) {
int ti = task.getTileY() * tilesX + task.getTileX();
tile_centers[ti] = task.getDoubleCenterXY();
}
if (show_tile_centers){
double [][] dbg_img = new double [6][tile_centers.length];
String [] dbg_titles = {"cX","cY","px0","py0", "px1","py1"};
for (int i = 0; i< dbg_img.length;i++) Arrays.fill(dbg_img[i], Double.NaN);
for (int t = 0; t < tp_tasks[0].length; t++) {
TpTask task0 = tp_tasks[0][t];
TpTask task1 = tp_tasks[1][t];
int ti = task0.getTileY() * tilesX + task0.getTileX();
dbg_img[0][ti] = task0.getDoubleCenterXY()[0]; // same for task0, task1
dbg_img[1][ti] = task1.getDoubleCenterXY()[1];
dbg_img[2][ti] = task0.getXY()[0][0];
dbg_img[3][ti] = task0.getXY()[0][1];
dbg_img[4][ti] = task1.getXY()[0][0];
dbg_img[5][ti] = task1.getXY()[0][1];
} // getXY()
ShowDoubleFloatArrays.showArrays(
dbg_img,
tilesX,
tile_centers.length/tilesX,
true,
"tile_centers",
dbg_titles);
}
OrthoPairLMA orthoPairLMA = new OrthoPairLMA();
// vector_field[1] - neighbors
double lambda = 0.1;
double lambda_scale_good = 0.5;
double lambda_scale_bad = 8.0;
double lambda_max = 100;
double rms_diff = 0.001;
int num_iter = 20;
boolean last_run = false;
orthoPairLMA.prepareLMA(
// will always calculate relative affine, starting with unity
tilesX, // int width,
vector_field[1], // double [][] vector_XYS, // optical flow X,Y, confidence obtained from the correlate2DIterate()
tile_centers, // double [][] centers, // tile centers (in pixels)
null, // double [] weights_extra, // optional, may be null
true, // boolean first_run,
debugLevel); // final int debug_level)
int lma_rslt = orthoPairLMA.runLma( // <0 - failed, >=0 iteration number (1 - immediately)
lambda, // double lambda, // 0.1
lambda_scale_good, // double lambda_scale_good,// 0.5
lambda_scale_bad, // double lambda_scale_bad, // 8.0
lambda_max, // double lambda_max, // 100
rms_diff, // double rms_diff, // 0.001
num_iter, // int num_iter, // 20
last_run, // boolean last_run,
debugLevel); // int debug_level)
System.out.println("LMA result = "+lma_rslt);
// analyze result, re-run correlation
/* /*
if (show_vector_field) { if (show_vector_field) {
......
/**
**
** OrthoPairLMA - Fit a pair of fixed-scale orthographic image by modifying
** affine transform matrix of the second one using 2D phase
** correlation implemented in GPU
**
** Copyright (C) 2024 Elphel, Inc.
**
** -----------------------------------------------------------------------------**
**
** OrthoPairLMA.java is free software: you can redistribute it and/or modify
** it under the terms of the GNU General Public License as published by
** the Free Software Foundation, either version 3 of the License, or
** (at your option) any later version.
**
** This program is distributed in the hope that it will be useful,
** but WITHOUT ANY WARRANTY; without even the implied warranty of
** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
** GNU General Public License for more details.
**
** You should have received a copy of the GNU General Public License
** along with this program. If not, see <http://www.gnu.org/licenses/>.
** -----------------------------------------------------------------------------**
**
*/
package com.elphel.imagej.orthomosaic;
import java.awt.Rectangle;
import java.util.concurrent.atomic.AtomicInteger;
import com.elphel.imagej.tileprocessor.ImageDtt;
import com.elphel.imagej.tileprocessor.QuadCLT;
import Jama.Matrix;
public class OrthoPairLMA {
private boolean thread_invariant= true;
private int N = 0;
private Rectangle woi = null;
private int width;
private double [] last_rms = null; // {rms, rms_pure}, matching this.vector
private double [] good_or_bad_rms = null; // just for diagnostics, to read last (failed) rms
private double [] initial_rms = null; // {rms, rms_pure}, first-calcualted rms
private double [] parameters_vector = null;
// private double [] x_vector = null; // not used. Save total weight
private double [] y_vector = null;
private double [][] tile_centers = null;
private double weight = 0; // total weight
private double [] weights; // normalized so sum is 1.0 for all - samples and extra regularization
// private double pure_weight; // weight of samples only
private double [] last_ymfx = null;
private double [][] last_jt = null;
public void prepareLMA(
// will always calculate relative affine, starting with unity
int width, // tilesX
double [][] vector_XYS, // optical flow X,Y, confidence obtained from the correlate2DIterate()
double [][] centers, // tile centers (in pixels)
double [] weights_extra, // optional, may be null
boolean first_run,
final int debug_level) {
tile_centers = centers;
this.width = width;
int height = vector_XYS.length / width;
int min_x = width, min_y=height, max_x = -1, max_y=-1;
for (int tile = 0; tile < vector_XYS.length; tile++) if ((vector_XYS[tile] !=null)) {
int tileX = tile % width;
int tileY = tile / width;
if (tileX < min_x) min_x = tileX;
if (tileX > max_x) max_x = tileX;
if (tileY < min_y) min_y = tileY;
if (tileY > max_y) max_y = tileY;
}
woi = new Rectangle (min_x, min_y, max_x - min_x + 1, max_y - min_y + 1);
N = woi.width*woi.height;
parameters_vector = new double [] {1,0,0,1,0,0};
setSamplesWeightsYCenters(
vector_XYS,
weights_extra, // null or additional weights (such as elevation-based)
centers);
last_jt = new double [parameters_vector.length][];
if (debug_level > 1) {
System.out.println("prepareLMA() 1");
}
// tile_centers = new double [N];
last_rms = new double [2];
// double [] fx = getFxDerivs(
// parameters_vector, // double [] vector,
// last_jt, // final double [][] jt, // should be null or initialized with [vector.length][]
// debug_level); // final int debug_level)
// last_ymfx = getYminusFxWeighted(
// fx, // final double [] fx,
// last_rms); // final double [] rms_fp // null or [2]
initial_rms = last_rms.clone();
good_or_bad_rms = this.last_rms.clone();
}
private void setSamplesWeightsYCenters(
final double [][] vector_XYS,
final double [] weights_extra, // null or additional weights (such as elevation-based)
final double [][] centers)
{
//num_components 2 - old, 3 - with disparity
this.weights = new double [2*N]; // same for X and Y
y_vector = new double[2*N];
tile_centers = new double [N][];
final Thread[] threads = ImageDtt.newThreadArray();
final AtomicInteger ai = new AtomicInteger(0);
final AtomicInteger ati = new AtomicInteger(0);
final double [] sw_arr = new double [threads.length];
// double sum_weights;
for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() {
public void run() {
int thread_num = ati.getAndIncrement();
for (int iTile = ai.getAndIncrement(); iTile < N; iTile = ai.getAndIncrement()) {
int tileX = iTile % woi.width + woi.x;
int tileY = iTile / woi.width + woi.y;
int aTile = tileY * width + tileX;
if ((vector_XYS[aTile] != null) && (centers[aTile] != null)) {
double w = vector_XYS[aTile][2];
if (weights_extra != null) w *= weights_extra[aTile];
if (Double.isNaN(w)) w = 0;
sw_arr[thread_num] += 2*w;
weights [2*iTile + 0] = w;
weights [2*iTile + 1] = w;
y_vector[2*iTile + 0] = vector_XYS[aTile][0];
y_vector[2*iTile + 1] = vector_XYS[aTile][1];
tile_centers[iTile] = centers[aTile];
}
}
}
};
}
ImageDtt.startAndJoin(threads);
weight = 0.0;
for (double w:sw_arr) {
weight += w;
}
if (weight <= 1E-8) {
System.out.println("!!!!!! setSamplesWeights(): sum_weights="+weight+" <= 1E-8");
}
ai.set(0);
final double s = 1.0/weight; // Was 0.5 - already taken care of
for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() {
public void run() {
for (int iTile = ai.getAndIncrement(); iTile < N; iTile = ai.getAndIncrement()) {
weights[2 * iTile ] *= s;
weights[2 * iTile + 1] *= s;
}
}
};
}
ImageDtt.startAndJoin(threads);
}
public int runLma( // <0 - failed, >=0 iteration number (1 - immediately)
double lambda, // 0.1
double lambda_scale_good,// 0.5
double lambda_scale_bad, // 8.0
double lambda_max, // 100
double rms_diff, // 0.001
int num_iter, // 20
boolean last_run,
int debug_level)
{
boolean [] rslt = {false,false};
this.last_rms = null; // remove?
int iter = 0;
for (iter = 0; iter < num_iter; iter++) {
rslt = lmaStep(
lambda,
rms_diff,
debug_level);
if (rslt == null) {
return -1; // false; // need to check
}
if (debug_level > 1) {
System.out.println("LMA step "+iter+": {"+rslt[0]+","+rslt[1]+"} full RMS= "+good_or_bad_rms[0]+
" ("+initial_rms[0]+"), pure RMS="+good_or_bad_rms[1]+" ("+initial_rms[1]+") + lambda="+lambda);
}
if (rslt[1]) {
break;
}
if (rslt[0]) { // good
lambda *= lambda_scale_good;
} else {
lambda *= lambda_scale_bad;
if (lambda > lambda_max) {
break; // not used in lwir
}
}
}
if (rslt[0]) { // better
if (iter >= num_iter) { // better, but num tries exceeded
if (debug_level > 1) System.out.println("Step "+iter+": Improved, but number of steps exceeded maximal");
} else {
if (debug_level > 1) System.out.println("Step "+iter+": LMA: Success");
}
} else { // improved over initial ?
if (last_rms[0] < initial_rms[0]) { // NaN
rslt[0] = true;
if (debug_level > 1) System.out.println("Step "+iter+": Failed to converge, but result improved over initial");
} else {
if (debug_level > 1) System.out.println("Step "+iter+": Failed to converge");
}
}
boolean show_intermediate = true;
if (show_intermediate && (debug_level > 0)) {
System.out.println("LMA: full RMS="+last_rms[0]+" ("+initial_rms[0]+"), pure RMS="+last_rms[1]+" ("+initial_rms[1]+") + lambda="+lambda);
}
if (debug_level > 2){
System.out.println("iteration="+iter);
}
if (debug_level > 0) {
if ((debug_level > 1) || last_run) { // (iter == 1) || last_run) {
if (!show_intermediate) {
System.out.println("LMA: iter="+iter+", full RMS="+last_rms[0]+" ("+initial_rms[0]+"), pure RMS="+last_rms[1]+" ("+initial_rms[1]+") + lambda="+lambda);
}
}
}
if ((debug_level > -2) && !rslt[0]) { // failed
if ((debug_level > 1) || (iter == 1) || last_run) {
System.out.println("LMA failed on iteration = "+iter);
}
System.out.println();
}
return rslt[0]? iter : -1;
}
private boolean [] lmaStep(
double lambda,
double rms_diff,
int debug_level) {
boolean [] rslt = {false,false};
// maybe the following if() branch is not needed - already done in prepareLMA !
if (this.last_rms == null) { //first time, need to calculate all (vector is valid)
last_rms = new double[2];
if (debug_level > 1) {
System.out.println("lmaStep(): first step");
}
double [] fx = getFxDerivs(
parameters_vector, // double [] vector,
last_jt, // final double [][] jt, // should be null or initialized with [vector.length][]
debug_level); // final int debug_level)
last_ymfx = getYminusFxWeighted(
fx, // final double [] fx,
last_rms); // final double [] rms_fp // null or [2]
this.initial_rms = this.last_rms.clone();
this.good_or_bad_rms = this.last_rms.clone();
if (debug_level > -1) { // temporary
/*
dbgYminusFxWeight(
this.last_ymfx,
this.weights,
"Initial_y-fX_after_moving_objects");
*/
}
if (last_ymfx == null) {
return null; // need to re-init/restart LMA
}
// TODO: Restore/implement
if (debug_level > 3) {
/*
dbgJacobians(
corr_vector, // GeometryCorrection.CorrVector corr_vector,
1E-5, // double delta,
true); //boolean graphic)
*/
}
}
Matrix y_minus_fx_weighted = new Matrix(this.last_ymfx, this.last_ymfx.length);
Matrix wjtjlambda = new Matrix(getWJtJlambda(
lambda, // *10, // temporary
this.last_jt)); // double [][] jt)
if (debug_level>2) {
System.out.println("JtJ + lambda*diag(JtJ");
wjtjlambda.print(18, 6);
}
Matrix jtjl_inv = null;
try {
jtjl_inv = wjtjlambda.inverse(); // check for errors
} catch (RuntimeException e) {
rslt[1] = true;
if (debug_level > 0) {
System.out.println("Singular Matrix!");
}
return rslt;
}
if (debug_level>2) {
System.out.println("(JtJ + lambda*diag(JtJ).inv()");
jtjl_inv.print(18, 6);
}
//last_jt has NaNs
Matrix jty = (new Matrix(this.last_jt)).times(y_minus_fx_weighted);
if (debug_level>2) {
System.out.println("Jt * (y-fx)");
jty.print(18, 6);
}
Matrix mdelta = jtjl_inv.times(jty);
if (debug_level>2) {
System.out.println("mdelta");
mdelta.print(18, 6);
}
double scale = 1.0;
double [] delta = mdelta.getColumnPackedCopy();
double [] new_vector = parameters_vector.clone();
for (int i = 0; i < parameters_vector.length; i++) {
new_vector[i] += scale * delta[i];
}
double [] fx = getFxDerivs(
new_vector, // double [] vector,
last_jt, // final double [][] jt, // should be null or initialized with [vector.length][]
debug_level); // final int debug_level)
double [] rms = new double[2];
last_ymfx = getYminusFxWeighted(
// vector_XYS, // final double [][] vector_XYS,
fx, // final double [] fx,
rms); // final double [] rms_fp // null or [2]
if (debug_level > 2) {
/*
dbgYminusFx(this.last_ymfx, "next y-fX");
dbgXY(new_vector, "XY-correction");
*/
}
if (last_ymfx == null) {
return null; // need to re-init/restart LMA
}
this.good_or_bad_rms = rms.clone();
if (rms[0] < this.last_rms[0]) { // improved
rslt[0] = true;
rslt[1] = rms[0] >=(this.last_rms[0] * (1.0 - rms_diff));
this.last_rms = rms.clone();
this.parameters_vector = new_vector.clone();
if (debug_level > 2) {
// print vectors in some format
/*
System.out.print("delta: "+corr_delta.toString()+"\n");
System.out.print("New vector: "+new_vector.toString()+"\n");
System.out.println();
*/
}
} else { // worsened
rslt[0] = false;
rslt[1] = false; // do not know, caller will decide
// restore state
fx = getFxDerivs(
parameters_vector, // double [] vector,
last_jt, // final double [][] jt, // should be null or initialized with [vector.length][]
debug_level); // final int debug_level)
last_ymfx = getYminusFxWeighted(
fx, // final double [] fx,
this.last_rms); // final double [] rms_fp // null or [2]
if (last_ymfx == null) {
return null; // need to re-init/restart LMA
}
if (debug_level > 2) {
/*
dbgJacobians(
corr_vector, // GeometryCorrection.CorrVector corr_vector,
1E-5, // double delta,
true); //boolean graphic)
*/
}
}
return rslt;
}
private double [][] getWJtJlambda( // USED in lwir
final double lambda,
final double [][] jt)
{
final int num_pars = jt.length;
final int num_pars2 = num_pars * num_pars;
final int nup_points = jt[0].length;
final double [][] wjtjl = new double [num_pars][num_pars];
final Thread[] threads = ImageDtt.newThreadArray();
final AtomicInteger ai = new AtomicInteger(0);
for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() {
public void run() {
for (int indx = ai.getAndIncrement(); indx < num_pars2; indx = ai.getAndIncrement()) {
int i = indx / num_pars;
int j = indx % num_pars;
if (j >= i) {
double d = 0.0;
for (int k = 0; k < nup_points; k++) {
if (jt[i][k] != 0) {
d+=0;
}
d += weights[k]*jt[i][k]*jt[j][k];
}
wjtjl[i][j] = d;
if (i == j) {
wjtjl[i][j] += d * lambda;
} else {
wjtjl[j][i] = d;
}
}
}
}
};
}
ImageDtt.startAndJoin(threads);
return wjtjl;
}
private double [] getFxDerivs(
final double [] vector,
final double [][] jt, // should be null or initialized with [vector.length][]
final int debug_level)
{
// final double [][] affine = {{vector[0],vector[1],vector[4]},{vector[2],vector[3],vector[5]}};
final double [] fx = new double [weights.length]; // weights.length]; : weights.length :
if (jt != null) {
for (int i = 0; i < jt.length; i++) {
jt[i] = new double [weights.length]; // weights.length];
}
}
final Thread[] threads = ImageDtt.newThreadArray();
final AtomicInteger ai = new AtomicInteger(0);
for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() {
public void run() {
for (int iTile = ai.getAndIncrement(); iTile < N; iTile = ai.getAndIncrement()) if (tile_centers[iTile] !=null) {
double x = tile_centers[iTile][0];
double y = tile_centers[iTile][1];
double vx = (vector[0] - 1.0) * x + vector[1] * y + vector[4];
double vy = vector[2] * x + (vector[3] - 1.0) * y + vector[5];
fx[2 * iTile + 0] = vx;
fx[2 * iTile + 1] = vy;
if (jt != null) {
jt[0][2 * iTile + 0] = x;
jt[1][2 * iTile + 0] = y;
jt[2][2 * iTile + 0] = 0.0;
jt[3][2 * iTile + 0] = 0.0;
jt[4][2 * iTile + 0] = 1.0;
jt[5][2 * iTile + 0] = 0.0;
jt[0][2 * iTile + 1] = 0.0;
jt[1][2 * iTile + 1] = 0.0;
jt[2][2 * iTile + 1] = x;
jt[3][2 * iTile + 1] = y;
jt[4][2 * iTile + 1] = 0.0;
jt[5][2 * iTile + 1] = 1.0;
}
}
}
};
}
ImageDtt.startAndJoin(threads);
return fx;
}
private double [] getYminusFxWeighted(
final double [] fx,
final double [] rms_fp // null or [2]
) {
if (thread_invariant) {
return getYminusFxWeightedInvariant(fx,rms_fp); // null or [>0]
} else {
return getYminusFxWeightedFast (fx,rms_fp); // null or [>0]
}
}
private double [] getYminusFxWeightedInvariant(
final double [] fx,
final double [] rms_fp // null or [2]
) {
final Thread[] threads = ImageDtt.newThreadArray();
final AtomicInteger ai = new AtomicInteger(0);
final double [] wymfw = new double [fx.length];
double s_rms;
final int num_samples = 2 * N;
final double [] l2_arr = new double [num_samples];
for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() {
public void run() {
for (int i = ai.getAndIncrement(); i < num_samples; i = ai.getAndIncrement()) if (weights[i] > 0) {
double d = y_vector[i] - fx[i];
double wd = d * weights[i];
if (Double.isNaN(wd)) {
System.out.println("getYminusFxWeighted(): weights["+i+"]="+weights[i]+", wd="+wd+
", y_vector[i]="+y_vector[i]+", fx[i]="+fx[i]);
}
//double l2 = d * wd;
l2_arr[i] = d * wd;
wymfw[i] = wd;
}
}
};
}
ImageDtt.startAndJoin(threads);
s_rms = 0.0;
for (double l2:l2_arr) {
s_rms += l2;
}
double rms = Math.sqrt(s_rms); // assuming sum_weights == 1.0; /pure_weight); shey should be re-normalized after adding regularization
if (rms_fp != null) {
rms_fp[0] = rms;
}
return wymfw;
}
private double [] getYminusFxWeightedFast(
final double [] fx,
final double [] rms_fp // null or [2]
) {
final int num_samples = 2 * N;
final Thread[] threads = ImageDtt.newThreadArray();
final AtomicInteger ai = new AtomicInteger(0);
final double [] wymfw = new double [fx.length];
final AtomicInteger ati = new AtomicInteger(0);
final double [] l2_arr = new double [threads.length];
for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() {
public void run() {
int thread_num = ati.getAndIncrement();
for (int i = ai.getAndIncrement(); i < num_samples; i = ai.getAndIncrement()) if (weights[i] > 0) {
double d = y_vector[i] - fx[i];
double wd = d * weights[i];
if (Double.isNaN(wd)) {
System.out.println("getYminusFxWeighted(): weights["+i+"]="+weights[i]+", wd="+wd+
", y_vector[i]="+y_vector[i]+", fx[i]="+fx[i]);
}
//double l2 = d * wd;
l2_arr[thread_num] += d * wd;
wymfw[i] = wd;
}
}
};
}
ImageDtt.startAndJoin(threads);
double s_rms = 0.0;
for (double l2:l2_arr) {
s_rms += l2;
}
double rms = Math.sqrt(s_rms); // assuming sum_weights == 1.0; /pure_weight); shey should be re-normalized after adding regularization
if (rms_fp != null) {
rms_fp[0] = rms;
}
return wymfw;
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment