Commit d87f98cc authored by Andrey Filippov's avatar Andrey Filippov

fixed LMA bug, commented thread_invariant for LMA

parent 1bbeb0e2
...@@ -1389,6 +1389,7 @@ public class Interscene { ...@@ -1389,6 +1389,7 @@ public class Interscene {
{ {
System.out.println("reAdjustPairsLMAInterscene(): using mb_max_gain="+mb_max_gain); System.out.println("reAdjustPairsLMAInterscene(): using mb_max_gain="+mb_max_gain);
boolean freeze_xy_pull = false; // true; // debugging freezing xy to xy_pull boolean freeze_xy_pull = false; // true; // debugging freezing xy to xy_pull
boolean copy_pull_current = true;
final boolean[] param_select = configured_lma? clt_parameters.ilp.ilma_lma_select : final boolean[] param_select = configured_lma? clt_parameters.ilp.ilma_lma_select :
ErsCorrection.getParamSelect( ErsCorrection.getParamSelect(
!freeze_xy_pull && (!readjust_xy_ims || (reg_weight_xy != 0)), // false only in mode c): freeze X,Y// boolean use_XY !freeze_xy_pull && (!readjust_xy_ims || (reg_weight_xy != 0)), // false only in mode c): freeze X,Y// boolean use_XY
...@@ -1545,7 +1546,7 @@ public class Interscene { ...@@ -1545,7 +1546,7 @@ public class Interscene {
earliest_scene, // int early_index, earliest_scene, // int early_index,
last_scene); // int last_index) last_scene); // int last_index)
} }
if (freeze_xy_pull) { if (copy_pull_current) { // freeze_xy_pull) {
System.out.println("reAdjustPairsLMAInterscene(): freezing X,Y to X,Y pull values"); System.out.println("reAdjustPairsLMAInterscene(): freezing X,Y to X,Y pull values");
for (int nscene = last_scene; nscene >= earliest_scene; nscene--) { for (int nscene = last_scene; nscene >= earliest_scene; nscene--) {
if (scenes_xyzatr_pull[nscene] != null) { if (scenes_xyzatr_pull[nscene] != null) {
...@@ -2261,13 +2262,14 @@ public class Interscene { ...@@ -2261,13 +2262,14 @@ public class Interscene {
ers_scene.ers_wxyz_center_dt = ers_ref.ers_wxyz_center_dt.clone(); ers_scene.ers_wxyz_center_dt = ers_ref.ers_wxyz_center_dt.clone();
*/ */
} }
// TODO: save ers_scene.ers_watr_center_dt and ers_scene.ers_wxyz_center_dt // TODO: save ers_scene.ers_watr_center_dt and ers_scene.ers_wxyz_center_dt
intersceneLma.prepareLMA( intersceneLma.prepareLMA(
camera_xyz0, // final double [] scene_xyz0, // camera center in world coordinates (or null to use instance) camera_xyz0, // final double [] scene_xyz0, // camera center in world coordinates (or null to use instance)
camera_atr0, // final double [] scene_atr0, // camera orientation relative to world frame (or null to use instance) camera_atr0, // final double [] scene_atr0, // camera orientation relative to world frame (or null to use instance)
scene_xyz_pull, // final double [] scene_xyz_pull, // if both are not null, specify target values to pull to scene_xyz_pull, // final double [] scene_xyz_pull, // if both are not null, specify target values to pull to
scene_atr_pull, // final double [] scene_atr_pull, // scene_atr_pull, // final double [] scene_atr_pull, //
// reference atr, xyz are considered 0.0 // reference atr, xyz are considered 0.0 not anymore?
scene_QuadClt, // final QuadCLT scene_QuadClt, scene_QuadClt, // final QuadCLT scene_QuadClt,
reference_QuadClt, // final QuadCLT reference_QuadClt, reference_QuadClt, // final QuadCLT reference_QuadClt,
param_select_mod, // param_select, // final boolean[] param_select, param_select_mod, // param_select, // final boolean[] param_select,
......
...@@ -58,16 +58,23 @@ public class IntersceneLma { ...@@ -58,16 +58,23 @@ public class IntersceneLma {
private double [][] macrotile_centers = null; // (will be used to pull for regularization) private double [][] macrotile_centers = null; // (will be used to pull for regularization)
private double infinity_disparity = 0.1; // treat lower as infinity private double infinity_disparity = 0.1; // treat lower as infinity
private int num_samples = 0; private int num_samples = 0;
///////////////////////////////////////////////////////////
// thread_invariant is needed for LMA; otherwise, even with the same parameter vector, the RMS may be slightly different.
// Keeping the thread-variant code where it is done once per LMA (during setup): in normalizeWeights() and setSamplesWeights().
// See the original code (with all the switching between variant/invariant) in revisions before 12/17/2023.
///////////////////////////////////////////////////////////
private boolean thread_invariant = true; // Do not use DoubleAdder, provide results not dependent on threads private boolean thread_invariant = true; // Do not use DoubleAdder, provide results not dependent on threads
private int num_components = 2; // 2 for 2D matching only,3 - include disparity private int num_components = 2; // 2 for 2D matching only,3 - include disparity
private double disparity_weight = 1.0; // relative weight of disparity errors private double disparity_weight = 1.0; // relative weight of disparity errors
public IntersceneLma( public IntersceneLma(
boolean thread_invariant, boolean thread_invariant,
double disparity_weight double disparity_weight
) { ) {
this.thread_invariant = thread_invariant;
this.num_components = (disparity_weight > 0) ? 3 : 2; this.num_components = (disparity_weight > 0) ? 3 : 2;
this.disparity_weight = disparity_weight; this.disparity_weight = disparity_weight;
this.thread_invariant= thread_invariant;
} }
public int getNumComponents() { public int getNumComponents() {
...@@ -133,22 +140,6 @@ public class IntersceneLma { ...@@ -133,22 +140,6 @@ public class IntersceneLma {
public String [] printOldNew(boolean allvectors, int w, int d) { public String [] printOldNew(boolean allvectors, int w, int d) {
return printOldNew(allvectors, 0, w, d); return printOldNew(allvectors, 0, w, d);
/*
String fmt1 = String.format("%%%d.%df", w+2,d+2); // more for the differences
ArrayList<String> lines = new ArrayList<String>();
for (int n = ErsCorrection.DP_DVAZ; n < ErsCorrection.DP_NUM_PARS; n+=3) {
boolean adj = false;
for (int i = 0; i <3; i++) adj |= par_mask[n+i];
if (allvectors || adj) {
String line = printNameV3(n, false, w,d)+" (" + getCompareType()+
" "+printNameV3(n, true, w,d)+")";
line += ", diff_last="+String.format(fmt1, getV3Diff(n)[0]);
line += ", diff_first="+String.format(fmt1, getV3Diff(n)[1]);
lines.add(line);
}
}
return lines.toArray(new String[lines.size()]);
*/
} }
public String [] printOldNew(boolean allvectors, int mode, int w, int d) { public String [] printOldNew(boolean allvectors, int mode, int w, int d) {
...@@ -198,29 +189,16 @@ public class IntersceneLma { ...@@ -198,29 +189,16 @@ public class IntersceneLma {
public String getCompareType(int mode) { public String getCompareType(int mode) {
boolean is_pull = (mode == 0) ? (parameters_pull != null) : (mode > 1); boolean is_pull = (mode == 0) ? (parameters_pull != null) : (mode > 1);
// return (parameters_pull != null)? "pull": "was"; return is_pull? "pull": "was ";
return is_pull? "pull": "was";
} }
public String printNameV3(int indx, boolean initial, int w, int d) { public String printNameV3(int indx, boolean initial, int w, int d) {
return printNameV3(indx, initial, 0, w, d); return printNameV3(indx, initial, 0, w, d);
/*
double [] full_vector = initial?
(use_pull? getFullVector(parameters_pull) : backup_parameters_full):
getFullVector(parameters_vector);
double [] vector = new double[3];
for (int i = 0; i <3; i++) {
vector[i] = full_vector[indx + i];
}
String name = ErsCorrection.DP_VECTORS_NAMES[indx];
return printNameV3(name, vector, w, d);
*/
} }
public String printNameV3(int indx, boolean initial, int mode, int w, int d) { public String printNameV3(int indx, boolean initial, int mode, int w, int d) {
// mode: 0 - auto, 1 - was, 2 - pull // mode: 0 - auto, 1 - was, 2 - pull
boolean use_pull = (mode == 0) ? (parameters_pull != null) : (mode > 1); boolean use_pull = (mode == 0) ? (parameters_pull != null) : (mode > 1);
// boolean use_pull = parameters_pull != null;
double [] full_vector = initial? double [] full_vector = initial?
(use_pull? getFullVector(parameters_pull) : backup_parameters_full): (use_pull? getFullVector(parameters_pull) : backup_parameters_full):
getFullVector(parameters_vector); getFullVector(parameters_vector);
...@@ -255,7 +233,7 @@ public class IntersceneLma { ...@@ -255,7 +233,7 @@ public class IntersceneLma {
final double [] scene_atr0, // camera orientation relative to world frame (or null to use instance) final double [] scene_atr0, // camera orientation relative to world frame (or null to use instance)
final double [] scene_xyz_pull, // if both are not null, specify target values to pull to final double [] scene_xyz_pull, // if both are not null, specify target values to pull to
final double [] scene_atr_pull, // final double [] scene_atr_pull, //
// reference atr, xyz are considered 0.0 // reference atr, xyz are considered 0.0 - not anymore?
final QuadCLT scene_QuadClt, final QuadCLT scene_QuadClt,
final QuadCLT reference_QuadClt, final QuadCLT reference_QuadClt,
final boolean[] param_select, final boolean[] param_select,
...@@ -360,6 +338,7 @@ public class IntersceneLma { ...@@ -360,6 +338,7 @@ public class IntersceneLma {
scene_QuadClt, // final QuadCLT scene_QuadClt, scene_QuadClt, // final QuadCLT scene_QuadClt,
reference_QuadClt, // final QuadCLT reference_QuadClt, reference_QuadClt, // final QuadCLT reference_QuadClt,
debug_level); // final int debug_level) debug_level); // final int debug_level)
// Why does y_vector start with the initial value of fx?
y_vector = fx.clone(); y_vector = fx.clone();
for (int i = 0; i < vector_XYSDS.length; i++) { for (int i = 0; i < vector_XYSDS.length; i++) {
if (vector_XYSDS[i] != null){ if (vector_XYSDS[i] != null){
...@@ -371,6 +350,13 @@ public class IntersceneLma { ...@@ -371,6 +350,13 @@ public class IntersceneLma {
} }
} }
if (parameters_pull != null){
for (int i = 0; i < par_indices.length; i++) {
// y_vector [i + num_components * macrotile_centers.length] += parameters_pull[i]; // - parameters_initial[i]; // scale will be combined with weights
y_vector [i + num_components * macrotile_centers.length] = parameters_pull[i]; // - parameters_initial[i]; // scale will be combined with weights
}
}
last_rms = new double [2]; last_rms = new double [2];
last_ymfx = getYminusFxWeighted( last_ymfx = getYminusFxWeighted(
fx, // final double [] fx, fx, // final double [] fx,
...@@ -516,18 +502,6 @@ public class IntersceneLma { ...@@ -516,18 +502,6 @@ public class IntersceneLma {
Matrix wjtjlambda = new Matrix(getWJtJlambda( Matrix wjtjlambda = new Matrix(getWJtJlambda(
lambda, // *10, // temporary lambda, // *10, // temporary
this.last_jt)); // double [][] jt) this.last_jt)); // double [][] jt)
if (debug_level > 3) {
try {
System.out.println("getFxDerivs(): getChecksum(this.y_vector)="+ getChecksum(this.y_vector));
System.out.println("getFxDerivs(): getChecksum(this.weights)="+ getChecksum(this.weights));
System.out.println("getFxDerivs(): getChecksum(this.last_ymfx)="+ getChecksum(this.last_ymfx));
System.out.println("getFxDerivs(): getChecksum(y_minus_fx_weighted)="+getChecksum(y_minus_fx_weighted));
System.out.println("getFxDerivs(): getChecksum(wjtjlambda)= "+getChecksum(wjtjlambda));
} catch (NoSuchAlgorithmException | IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
if (debug_level>2) { if (debug_level>2) {
System.out.println("JtJ + lambda*diag(JtJ"); System.out.println("JtJ + lambda*diag(JtJ");
...@@ -554,16 +528,6 @@ public class IntersceneLma { ...@@ -554,16 +528,6 @@ public class IntersceneLma {
System.out.println("Jt * (y-fx)"); System.out.println("Jt * (y-fx)");
jty.print(18, 6); jty.print(18, 6);
} }
if (debug_level > 2) {
try {
System.out.println("getFxDerivs(): getChecksum(jtjl_inv)="+getChecksum(jtjl_inv));
System.out.println("getFxDerivs(): getChecksum(jty)= "+getChecksum(jty));
} catch (NoSuchAlgorithmException | IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
Matrix mdelta = jtjl_inv.times(jty); Matrix mdelta = jtjl_inv.times(jty);
...@@ -578,18 +542,6 @@ public class IntersceneLma { ...@@ -578,18 +542,6 @@ public class IntersceneLma {
for (int i = 0; i < parameters_vector.length; i++) { for (int i = 0; i < parameters_vector.length; i++) {
new_vector[i] += scale * delta[i]; new_vector[i] += scale * delta[i];
} }
if (debug_level > 2) {
try {
System.out.println("getFxDerivs(): getChecksum(mdelta)= "+getChecksum(mdelta));
System.out.println("getFxDerivs(): getChecksum(delta)= "+getChecksum(delta));
System.out.println("getFxDerivs(): getChecksum(parameters_vector)= "+getChecksum(parameters_vector));
System.out.println("getFxDerivs(): getChecksum(new_vector)= "+getChecksum(new_vector));
} catch (NoSuchAlgorithmException | IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
double [] fx = getFxDerivs( double [] fx = getFxDerivs(
new_vector, // double [] vector, new_vector, // double [] vector,
...@@ -666,71 +618,45 @@ public class IntersceneLma { ...@@ -666,71 +618,45 @@ public class IntersceneLma {
final Thread[] threads = ImageDtt.newThreadArray(QuadCLT.THREADS_MAX); final Thread[] threads = ImageDtt.newThreadArray(QuadCLT.THREADS_MAX);
final AtomicInteger ai = new AtomicInteger(0); final AtomicInteger ai = new AtomicInteger(0);
final AtomicInteger ati = new AtomicInteger(0);
final double [] sw_arr = new double [threads.length];
double sum_weights; double sum_weights;
if (thread_invariant) { final int num_types = (num_components > 2) ? 2 : 1;
final int num_types = (num_components > 2) ? 2 : 1; for (int ithread = 0; ithread < threads.length; ithread++) {
final double [] sw_arr = new double [num_types * vector_XYSDS.length]; threads[ithread] = new Thread() {
for (int ithread = 0; ithread < threads.length; ithread++) { public void run() {
threads[ithread] = new Thread() { int thread_num = ati.getAndIncrement();
public void run() { for (int iMTile = ai.getAndIncrement(); iMTile < vector_XYSDS.length; iMTile = ai.getAndIncrement()) if (vector_XYSDS[iMTile] != null){
for (int iMTile = ai.getAndIncrement(); iMTile < vector_XYSDS.length; iMTile = ai.getAndIncrement()) if (vector_XYSDS[iMTile] != null){ double w = vector_XYSDS[iMTile][2];
double w = vector_XYSDS[iMTile][2]; if (Double.isNaN(w)) {
if (Double.isNaN(w)) { w = 0;
w = 0;
}
weights[num_components * iMTile] = w;
sw_arr[num_types * iMTile] = 2*w;
//disparity_weight
if (num_types > 1) {
w = vector_XYSDS[iMTile][4] * disparity_weight;
if (Double.isNaN(w) || Double.isNaN(vector_XYSDS[iMTile][3])) {
w = 0;
vector_XYSDS[iMTile][3] = 0.0;
}
weights[num_components * iMTile + 2] = w;
sw_arr[num_types * iMTile + 1] = w;
}
} }
} weights[num_components * iMTile] = w;
}; sw_arr[thread_num] += 2*w;
} //disparity_weight
ImageDtt.startAndJoin(threads); if (num_types > 1) {
sum_weights = 0.0;
for (double w:sw_arr) {
sum_weights += w;
}
} else {
final DoubleAdder asum_weight = new DoubleAdder();
for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() {
public void run() {
for (int iMTile = ai.getAndIncrement(); iMTile < vector_XYSDS.length; iMTile = ai.getAndIncrement()) if (vector_XYSDS[iMTile] != null){
double w = vector_XYSDS[iMTile][2];
if (Double.isNaN(w)) {
w = 0;
}
weights[num_components * iMTile] = w;
asum_weight.add(w);
w = vector_XYSDS[iMTile][4] * disparity_weight; w = vector_XYSDS[iMTile][4] * disparity_weight;
if (Double.isNaN(w) || Double.isNaN(vector_XYSDS[iMTile][3])) { if (Double.isNaN(w) || Double.isNaN(vector_XYSDS[iMTile][3])) {
w = 0; w = 0;
vector_XYSDS[iMTile][3] = 0.0; vector_XYSDS[iMTile][3] = 0.0;
} }
weights[num_components * iMTile + 2] = w; weights[num_components * iMTile + 2] = w;
asum_weight.add(2*w); sw_arr[thread_num] += w;
} }
} }
}; }
} };
ImageDtt.startAndJoin(threads); }
sum_weights = asum_weight.sum(); ImageDtt.startAndJoin(threads);
sum_weights = 0.0;
for (double w:sw_arr) {
sum_weights += w;
} }
if (sum_weights <= 1E-8) { if (sum_weights <= 1E-8) {
System.out.println("!!!!!! setSamplesWeights(): sum_weights=="+sum_weights+" <= 1E-8"); System.out.println("!!!!!! setSamplesWeights(): sum_weights=="+sum_weights+" <= 1E-8");
} }
ai.set(0); ai.set(0);
// final double s = 0.5/sum_weights; final double s = 1.0/sum_weights; // Was 0.5 - already taken care of
final double s = 1.0/sum_weights; // already taken care of
for (int ithread = 0; ithread < threads.length; ithread++) { for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() { threads[ithread] = new Thread() {
public void run() { public void run() {
...@@ -738,7 +664,7 @@ public class IntersceneLma { ...@@ -738,7 +664,7 @@ public class IntersceneLma {
weights[num_components * iMTile] *= s; weights[num_components * iMTile] *= s;
weights[num_components * iMTile + 1] = weights[num_components * iMTile]; weights[num_components * iMTile + 1] = weights[num_components * iMTile];
if (num_components > 2) { if (num_components > 2) {
weights[num_components * iMTile + 2] *=s; weights[num_components * iMTile + 2] *= s;
} }
} }
} }
...@@ -751,52 +677,35 @@ public class IntersceneLma { ...@@ -751,52 +677,35 @@ public class IntersceneLma {
private void normalizeWeights() private void normalizeWeights()
{ {
final Thread[] threads = ImageDtt.newThreadArray(QuadCLT.THREADS_MAX); final Thread[] threads = ImageDtt.newThreadArray(QuadCLT.THREADS_MAX);
final AtomicInteger ai = new AtomicInteger(0); final AtomicInteger ai = new AtomicInteger(0);
double full_weight, sum_weight_pure; final AtomicInteger ati = new AtomicInteger(0);
if (thread_invariant) { final double [] sum_weight = new double [threads.length];
sum_weight_pure = 00; ati.set(0);
for (int i = 0; i < num_samples; i++) { for (int ithread = 0; ithread < threads.length; ithread++) {
if (Double.isNaN(weights[i])) { threads[ithread] = new Thread() {
System.out.println("normalizeWeights(): weights["+i+"] == NaN: 1"); public void run() {
weights[i] = 0; int thread_num = ati.getAndIncrement();
} for (int i = ai.getAndIncrement(); i < num_samples; i = ai.getAndIncrement()){
sum_weight_pure += weights[i]; if (Double.isNaN(weights[i])) {
} System.out.println("normalizeWeights(): weights["+i+"== NaN: 2");
full_weight = sum_weight_pure; weights[i] = 0;
for (int i = 0; i < par_indices.length; i++) {
int indx = num_samples + i;
if (Double.isNaN(weights[indx])) {
System.out.println("normalizeWeights(): weights["+indx+"] == NaN: 1.5, i="+i);
weights[indx] = 0;
}
full_weight += weights[indx];
}
} else {
final DoubleAdder asum_weight = new DoubleAdder();
for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() {
public void run() {
for (int i = ai.getAndIncrement(); i < num_samples; i = ai.getAndIncrement()){
if (Double.isNaN(weights[i])) {
System.out.println("normalizeWeights(): weights["+i+"== NaN: 2");
weights[i] = 0;
}
asum_weight.add(weights[i]);
} }
sum_weight[thread_num] += weights[i];
} }
};
}
ImageDtt.startAndJoin(threads);
sum_weight_pure = asum_weight.sum();
for (int i = 0; i < par_indices.length; i++) {
int indx = num_samples + i;
if (Double.isNaN(weights[indx])) {
System.out.println("normalizeWeights(): weights["+indx+"] == NaN: 1.5, i="+i);
weights[indx] = 0;
} }
asum_weight.add(weights[indx]); };
}
ImageDtt.startAndJoin(threads);
double sum_weight_pure = 0;
for (double sw:sum_weight) sum_weight_pure += sw;
double full_weight = sum_weight_pure;
for (int i = 0; i < par_indices.length; i++) {
int indx = num_samples + i;
if (Double.isNaN(weights[indx])) {
System.out.println("normalizeWeights(): weights["+indx+"] == NaN: 1.5, i="+i);
weights[indx] = 0;
} }
full_weight = asum_weight.sum(); full_weight += weights[indx];
} }
pure_weight = sum_weight_pure/full_weight; pure_weight = sum_weight_pure/full_weight;
final double s = 1.0/full_weight; final double s = 1.0/full_weight;
...@@ -939,25 +848,14 @@ public class IntersceneLma { ...@@ -939,25 +848,14 @@ public class IntersceneLma {
} }
// pull to the initial parameter values // pull to the initial parameter values
for (int i = 0; i < par_indices.length; i++) { for (int i = 0; i < par_indices.length; i++) {
fx [i + num_components * macrotile_centers.length] = vector[i]; // - parameters_initial[i]; // scale will be combined with weights fx [i + num_samples] = vector[i]; // - parameters_initial[i]; // scale will be combined with weights
jt[i][i + num_components * macrotile_centers.length] = 1.0; // scale will be combined with weights jt[i][i + num_samples] = 1.0; // scale will be combined with weights
} }
if (parameters_pull != null){ /// if (parameters_pull != null){
for (int i = 0; i < par_indices.length; i++) { /// for (int i = 0; i < par_indices.length; i++) {
fx [i + num_components * macrotile_centers.length] -= parameters_pull[i]; // - parameters_initial[i]; // scale will be combined with weights /// fx [i + num_samples] -= parameters_pull[i]; // - parameters_initial[i]; // scale will be combined with weights
} /// }
} /// }
if (debug_level > 3) {
try {
System.out.println ("getFxDerivs(): getChecksum(fx)="+getChecksum(fx));
if (jt != null) {
System.out.println("getFxDerivs(): getChecksum(jt)="+getChecksum(jt));
}
} catch (NoSuchAlgorithmException | IOException e) {
// TODO Auto-generated catch block
e.printStackTrace();
}
}
return fx; return fx;
} }
...@@ -1004,51 +902,92 @@ public class IntersceneLma { ...@@ -1004,51 +902,92 @@ public class IntersceneLma {
final double [] fx, final double [] fx,
final double [] rms_fp // null or [2] final double [] rms_fp // null or [2]
) { ) {
if (thread_invariant) {
return getYminusFxWeightedInvariant(fx,rms_fp); // null or [2]
} else {
return getYminusFxWeightedFast (fx,rms_fp); // null or [2]
}
}
private double [] getYminusFxWeightedInvariant(
final double [] fx,
final double [] rms_fp // null or [2]
) {
final Thread[] threads = ImageDtt.newThreadArray(QuadCLT.THREADS_MAX); final Thread[] threads = ImageDtt.newThreadArray(QuadCLT.THREADS_MAX);
final AtomicInteger ai = new AtomicInteger(0); final AtomicInteger ai = new AtomicInteger(0);
final double [] wymfw = new double [fx.length]; final double [] wymfw = new double [fx.length];
double s_rms; double s_rms;
if (thread_invariant) { final double [] l2_arr = new double [num_samples];
final double [] l2_arr = new double [num_samples]; for (int ithread = 0; ithread < threads.length; ithread++) {
for (int ithread = 0; ithread < threads.length; ithread++) { threads[ithread] = new Thread() {
threads[ithread] = new Thread() { public void run() {
public void run() { for (int i = ai.getAndIncrement(); i < num_samples; i = ai.getAndIncrement()) if (weights[i] > 0) {
for (int i = ai.getAndIncrement(); i < num_samples; i = ai.getAndIncrement()) if (weights[i] > 0) { double d = y_vector[i] - fx[i];
double d = y_vector[i] - fx[i]; double wd = d * weights[i];
double wd = d * weights[i]; if (Double.isNaN(wd)) {
if (Double.isNaN(wd)) { System.out.println("getYminusFxWeighted(): weights["+i+"]="+weights[i]+", wd="+wd+
System.out.println("getYminusFxWeighted(): weights["+i+"]="+weights[i]+", wd="+wd+ ", y_vector[i]="+y_vector[i]+", fx[i]="+fx[i]);
", y_vector[i]="+y_vector[i]+", fx[i]="+fx[i]);
}
//double l2 = d * wd;
l2_arr[i] = d * wd;
wymfw[i] = wd;
} }
//double l2 = d * wd;
l2_arr[i] = d * wd;
wymfw[i] = wd;
} }
}; }
} };
ImageDtt.startAndJoin(threads); }
s_rms = 0.0; ImageDtt.startAndJoin(threads);
for (double l2:l2_arr) { s_rms = 0.0;
s_rms += l2; for (double l2:l2_arr) {
} s_rms += l2;
} else { }
final DoubleAdder asum_weight = new DoubleAdder(); double rms_pure = Math.sqrt(s_rms/pure_weight);
for (int ithread = 0; ithread < threads.length; ithread++) { for (int i = 0; i < par_indices.length; i++) {
threads[ithread] = new Thread() { int indx = i + num_samples;
public void run() { double d = y_vector[indx] - fx[indx]; // fx[indx] == vector[i]
for (int i = ai.getAndIncrement(); i < num_samples; i = ai.getAndIncrement()) { double wd = d * weights[indx];
double d = y_vector[i] - fx[i]; s_rms += d * wd;
double wd = d * weights[i]; wymfw[indx] = wd;
double l2 = d * wd; }
wymfw[i] = wd; double rms = Math.sqrt(s_rms); // assuming sum_weights == 1.0; /pure_weight); shey should be re-normalized after adding regularization
asum_weight.add(l2); if (rms_fp != null) {
rms_fp[0] = rms;
rms_fp[1] = rms_pure;
}
return wymfw;
}
private double [] getYminusFxWeightedFast(
final double [] fx,
final double [] rms_fp // null or [2]
) {
final Thread[] threads = ImageDtt.newThreadArray(QuadCLT.THREADS_MAX);
final AtomicInteger ai = new AtomicInteger(0);
final double [] wymfw = new double [fx.length];
final AtomicInteger ati = new AtomicInteger(0);
final double [] l2_arr = new double [threads.length];
for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() {
public void run() {
int thread_num = ati.getAndIncrement();
for (int i = ai.getAndIncrement(); i < num_samples; i = ai.getAndIncrement()) if (weights[i] > 0) {
double d = y_vector[i] - fx[i];
double wd = d * weights[i];
if (Double.isNaN(wd)) {
System.out.println("getYminusFxWeighted(): weights["+i+"]="+weights[i]+", wd="+wd+
", y_vector[i]="+y_vector[i]+", fx[i]="+fx[i]);
} }
//double l2 = d * wd;
l2_arr[thread_num] += d * wd;
wymfw[i] = wd;
} }
}; }
} };
ImageDtt.startAndJoin(threads); }
s_rms = asum_weight.sum(); ImageDtt.startAndJoin(threads);
double s_rms = 0.0;
for (double l2:l2_arr) {
s_rms += l2;
} }
double rms_pure = Math.sqrt(s_rms/pure_weight); double rms_pure = Math.sqrt(s_rms/pure_weight);
for (int i = 0; i < par_indices.length; i++) { for (int i = 0; i < par_indices.length; i++) {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment