Commit 41015a35 authored by Andrey Filippov

multithreaded LMA

parent c4070b30
......@@ -4165,6 +4165,15 @@ if (MORE_BUTTONS) {
double rms= FOCUSING_FIELD.calcErrorDiffY(focusing_fx, false);
double rms_pure= FOCUSING_FIELD.calcErrorDiffY(focusing_fx, true);
System.out.println("rms="+rms+", rms_pure="+rms_pure+" - with old parameters may be well off.");
if (FOCUS_MEASUREMENT_PARAMETERS.scanRunLMA){
FOCUSING_FIELD.setAdjustMode(false);
boolean OK=FOCUSING_FIELD.LevenbergMarquardt(
null, // measurement
false, // true, // open dialog
true,// boolean autoSel,
DEBUG_LEVEL); //boolean openDialog, int debugLevel){
if (OK) saveCurrentConfig();
}
remoteNotifyComplete();
return;
}
......
......@@ -54,6 +54,7 @@ import org.apache.commons.configuration.XMLConfiguration;
//import Distortions.LMAArrays; // may still reuse?
import Jama.LUDecomposition;
import Jama.Matrix;
......@@ -236,7 +237,13 @@ public class FocusingField {
private double firstRMS=-1.0; // RMS before current series of LMA started
private double firstRMSPure=-1.0; // presumably the "pure" RMS before current series of LMA started (cf. currentRMSPure) - TODO confirm
public int threadsMax=0; // maximal number of threads for the multithreaded LMA code; 0 - old (single-threaded) code
private boolean multiJacobian=true; // when true and threadsMax>0, the *Multi variants of createFXandJacobian()/calculateJacobianArrays() are used
/**
 * Sets the maximal number of worker threads used by the multithreaded
 * LMA code. A value of 0 selects the old single-threaded code path.
 *
 * @param num new value for {@code threadsMax}
 */
public void setThreads(int num){
    threadsMax = num;
}
public void setDefaults(){
goodCalibratedSamples=null;
sensorWidth= 2592;
......@@ -1942,9 +1949,187 @@ public double [] createFXandJacobian( double [] vector, boolean createJacobian){
commitParameterVector(vector);
return createFXandJacobian(createJacobian);
}
/**
 * Calculates the function values (and optionally the Jacobian) for the
 * currently committed parameters, dispatching to the multithreaded
 * implementation when it is enabled ({@code multiJacobian}) and at least
 * one thread is allowed ({@code threadsMax>0}).
 *
 * @param createJacobian true to also build the Jacobian
 * @return array of function values
 */
public double [] createFXandJacobian(boolean createJacobian){
    final boolean useThreads = multiJacobian && (threadsMax > 0);
    return useThreads
            ? createFXandJacobianMulti(createJacobian)
            : createFXandJacobianSingle(createJacobian);
}
/**
 * Per-measurement result produced by a worker thread of
 * createFXandJacobianMulti(): the function value and (optionally) the
 * row of partial derivatives for one element of the data vector.
 */
public class PartialFXJac{
    /** Measurement number (index in the data vector). */
    public int index;
    /** Function value for this measurement. */
    public double f;
    /** Partial derivatives for this measurement, or null if not calculated. */
    public double [] jac;

    /**
     * @param index measurement number
     * @param f     function value
     * @param jac   derivatives array (stored by reference, may be null)
     */
    public PartialFXJac (int index, double f, double [] jac){
        this.jac = jac;
        this.f = f;
        this.index = index;
    }
}
public double [] createFXandJacobianMulti(
final boolean createJacobian
){
long startTime=System.nanoTime();
int numCorrPar=fieldFitting.getNumberOfCorrParameters();
boolean [] selChannels=fieldFitting.getSelectedChannels();
final int [] selChanIndices= new int[selChannels.length];
selChanIndices[0]=0;
for (int i=1;i<selChanIndices.length;i++){
selChanIndices[i]= selChanIndices[i-1]+(selChannels[i-1]?1:0);
}
final int numPars=fieldFitting.getNumberOfParameters(sagittalMaster);
int numRegPars=fieldFitting.getNumberOfRegularParameters(sagittalMaster);
final int numSelChn=fieldFitting.getNumberOfChannels();
final Thread[] threads = newThreadArray(threadsMax);
final ArrayList<ArrayList<PartialFXJac>> fxList = new ArrayList<ArrayList<PartialFXJac>>();
for (int ithread = 0; ithread < threads.length; ithread++) {
fxList.add(new ArrayList<PartialFXJac>());
}
// create list of indices of measurements corresponding to new timestamp/sample
String prevTimeStamp="";
double prevPx=-1,prevPy=-1;
final ArrayList<Integer> measIndicesList = new ArrayList<Integer>(dataVector.length/getNumChannels());
for (int n=0;n<dataVector.length;n++){
MeasuredSample ms=dataVector[n];
if (!ms.timestamp.equals(prevTimeStamp) || (ms.px!=prevPx) || (ms.py!=prevPy)){
measIndicesList.add(new Integer(n));
}
}
// measIndicesList.add(new Integer(dataVector.length));
final AtomicInteger measIndex = new AtomicInteger(0);
final AtomicInteger threadIndexAtomic = new AtomicInteger(0);
// final boolean [] falseFalse={false,false};
final boolean [] centerSelect=correct_measurement_ST?fieldFitting.getCenterSelect():null; //falseFalse;
for (int ithread = 0; ithread < threads.length; ithread++) {
threads[ithread] = new Thread() {
public void run() {
int threadIndex=threadIndexAtomic.getAndIncrement();
fxList.get(threadIndex).clear(); // not needed
double [][] derivs;
for (int startMeasIndex=measIndex.getAndIncrement(); startMeasIndex<measIndicesList.size();startMeasIndex=measIndex.getAndIncrement()){
int startMeas=measIndicesList.get(startMeasIndex);
int endMeas=(startMeasIndex==(measIndicesList.size()-1))?dataVector.length:measIndicesList.get(startMeasIndex+1);
MeasuredSample ms=dataVector[startMeas];
derivs=createJacobian?(new double[numSelChn][]):null;
double [] subData=fieldFitting.getValsDerivatives(
ms.sampleIndex,
sagittalMaster,
ms.motors, // 3 motor coordinates
ms.px, // pixel x
ms.py, // pixel y
derivs);
for (int n=startMeas;n<endMeas;n++){
int chn=selChanIndices[ms.channel];
if (centerSelect!=null){
int np=0;
for (int i=0;i<2;i++) if (centerSelect[i]){
derivs[chn][np++]-=ms.dPxyc[i]; // subtract, as effect is opposite to fX
}
}
PartialFXJac partialFXJac = new PartialFXJac(n,
subData[chn],
derivs[chn]);
fxList.get(threadIndex).add(partialFXJac);
}
}
public double [] createFXandJacobian(boolean createJacobian){
}
};
}
startAndJoin(threads);
if (debugLevel>1) System.out.println("#1 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-startTime),5));
// Combibe results
double [] fx=new double[dataVector.length + numCorrPar ];
if (createJacobian) {
jacobian=new double [numPars][dataVector.length+numCorrPar];
for (double [] row : jacobian) Arrays.fill(row, 0.0);
}
for (ArrayList<PartialFXJac> partilaList:fxList){
for (PartialFXJac partialFXJac:partilaList){
int n=partialFXJac.index;
fx[n]=partialFXJac.f;
if (createJacobian){
for (int i=0;i<numPars;i++){
jacobian[i][n]=partialFXJac.jac[i];
}
}
}
}
if (debugLevel>1) System.out.println("#2 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-startTime),5));
if (createJacobian && (fieldFitting.sampleCorrChnParIndex!=null)) {
// add mutual dependence of correction parameters. first - values (fx)
int index=dataVector.length; // add to the end of vector
int numSamples=getNumSamples();
for (int chn=0;chn<fieldFitting.sampleCorrChnParIndex.length;chn++) if (fieldFitting.sampleCorrChnParIndex[chn]!=null) {
for (int np=0;np<fieldFitting.sampleCorrChnParIndex[chn].length;np++){
int pindex=fieldFitting.sampleCorrChnParIndex[chn][np];
if (pindex>=0) {
for (int i=0;i<numSamples;i++){
double f=0.0;
for (int j=0;j<numSamples;j++){
f+=fieldFitting.sampleCorrVector[pindex+j]*fieldFitting.sampleCorrCrossWeights[chn][np][i][j];
}
fx[index]=f;
// f+=fieldFitting.sampleCorrVector[pindex+i]
if (createJacobian) {
for (int j=0;j<numSamples;j++){
jacobian[numRegPars+pindex+j][index]=fieldFitting.sampleCorrCrossWeights[chn][np][i][j];
}
}
index++;
}
}
}
}
}
if (debugLevel>1) System.out.println("#3 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-startTime),5));
if (createJacobian && (debugLevel>1)){
if (debugPoint>=0) debugJacobianPoint(debugPoint);
if (debugParameter>=0) debugJacobianParameter(debugParameter);
}
if (debugLevel>1) System.out.println("#4 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-startTime),5));
return fx;
}
/* Create a Thread[] array as large as the number of processors available,
 * but not larger than maxCPUs and never smaller than one element (so a
 * non-positive maxCPUs can neither produce an empty array, which would
 * silently skip all the work, nor a NegativeArraySizeException).
 * From Stephan Preibisch's Multithreading.java class. See:
 * http://repo.or.cz/w/trakem2.git?a=blob;f=mpi/fruitfly/general/MultiThreading.java;hb=HEAD
 */
private Thread[] newThreadArray(int maxCPUs) {
    int n_cpus = Math.min(Runtime.getRuntime().availableProcessors(), maxCPUs);
    return new Thread[Math.max(1, n_cpus)];
}
/* Start all given threads and wait on each of them until all are done.
 * If the waiting thread is interrupted, its interrupt status is restored
 * before the InterruptedException is rethrown wrapped in a RuntimeException,
 * so callers further up can still detect the interruption.
 * From Stephan Preibisch's Multithreading.java class. See:
 * http://repo.or.cz/w/trakem2.git?a=blob;f=mpi/fruitfly/general/MultiThreading.java;hb=HEAD
 */
private static void startAndJoin(Thread[] threads)
{
    for (Thread thread : threads)
    {
        thread.setPriority(Thread.NORM_PRIORITY);
        thread.start();
    }
    try
    {
        for (Thread thread : threads)
            thread.join();
    } catch (InterruptedException ie)
    {
        // join() cleared the flag when it threw - set it again
        Thread.currentThread().interrupt();
        throw new RuntimeException(ie);
    }
}
public double [] createFXandJacobianSingle(boolean createJacobian){
int numCorrPar=fieldFitting.getNumberOfCorrParameters();
double [] fx=new double[dataVector.length + numCorrPar ];
double [][] derivs=null;
......@@ -1967,11 +2152,6 @@ public double [] createFXandJacobian(boolean createJacobian){
double prevPx=-1,prevPy=-1;
for (int n=0;n<dataVector.length;n++){
MeasuredSample ms=dataVector[n];
// int saveDebugLevel=debugLevel;
// if (n==9) {
// System.out.println("createFXandJacobian(): n="+n);
// debugLevel=10;
// }
if (!ms.timestamp.equals(prevTimeStamp) || (ms.px!=prevPx) || (ms.py!=prevPy)){
subData=fieldFitting.getValsDerivatives(
ms.sampleIndex,
......@@ -1984,7 +2164,6 @@ public double [] createFXandJacobian(boolean createJacobian){
prevPx=ms.px;
prevPy=ms.py;
}
// debugLevel=saveDebugLevel; // restore debugLevel
fx[n]=subData[selChanIndices[ms.channel]];
if (createJacobian) {
double [] thisDerivs=derivs[selChanIndices[ms.channel]];
......@@ -2011,6 +2190,8 @@ public double [] createFXandJacobian(boolean createJacobian){
}
}
}
}
if (createJacobian) {
// add mutual dependence of correction parameters. first - values (fx)
// System.out.println("Using sampleCorrVector 1");
int index=dataVector.length; // add to the end of vector
......@@ -2611,6 +2792,11 @@ d_s2/d_x0= 2*delta_x*delta_y^2/r2^2
/**
 * Calculates the LMA arrays (J^T*J and J^T*diff) from the current Jacobian,
 * dispatching to the multithreaded implementation when it is enabled
 * ({@code multiJacobian}) and at least one thread is allowed
 * ({@code threadsMax>0}).
 *
 * @param fX function values matching the current Jacobian
 * @return the calculated arrays
 */
public LMAArrays calculateJacobianArrays(double [] fX){
    final boolean useThreads = multiJacobian && (threadsMax > 0);
    return useThreads
            ? calculateJacobianArraysMulti(fX)
            : calculateJacobianArraysSingle(fX);
}
public LMAArrays calculateJacobianArraysSingle(double [] fX){
// calculate JtJ
double [] diff=calcYminusFx(fX);
int numPars=this.jacobian.length; // number of parameters to be adjusted
......@@ -2625,7 +2811,6 @@ d_s2/d_x0= 2*delta_x*delta_y^2/r2^2
for (int k=0;k<length;k++) JtByJmod[i][j]+=this.jacobian[i][k]*this.jacobian[j][k];
}
for (int i=0;i<numPars;i++) { // subtract lambda*diagonal , fill the symmetrical half below the diagonal
// JtByJmod[i][i]+=lambda*JtByJmod[i][i]; //Marquardt mod
for (int j=0;j<i;j++) JtByJmod[i][j]= JtByJmod[j][i]; // it is symmetrical matrix, just copy
}
for (int i=0;i<numPars;i++) {
......@@ -2642,6 +2827,76 @@ d_s2/d_x0= 2*delta_x*delta_y^2/r2^2
lMAArrays.jTByDiff=JtByDiff;
return lMAArrays;
}
/**
 * Multithreaded calculation of J^T*J and J^T*(y-f(x)) from
 * {@code this.jacobian}, optionally weighted by {@code this.dataWeights}.
 *
 * Rows of the result are handed out to worker threads through an atomic
 * counter. Each worker fills the diagonal and above-diagonal elements of its
 * row of J^T*J and the matching element of J^T*diff; the lower triangle is
 * mirrored afterwards.
 *
 * @param fX function values matching the current Jacobian
 * @return LMAArrays with jTByJ and jTByDiff set
 */
public LMAArrays calculateJacobianArraysMulti(double [] fX){
    final double [] diff=calcYminusFx(fX);
    final int numPars=this.jacobian.length; // number of parameters to be adjusted
    final double [][] JtByJmod=new double [numPars][numPars]; // transposed Jacobian multiplied by Jacobian
    final double [] JtByDiff=new double [numPars];
    final double [] fWeights=this.dataWeights;
    final AtomicInteger nextRow = new AtomicInteger(0);
    final Thread[] threads = newThreadArray(threadsMax);
    for (int ithread = 0; ithread < threads.length; ithread++) {
        threads[ithread] = new Thread() {
            public void run() {
                int row;
                while ((row=nextRow.getAndIncrement())<numPars){
                    // row of the Jacobian, pre-multiplied by the weights when they are defined
                    double [] weighted=jacobian[row];
                    if (fWeights!=null){
                        weighted=jacobian[row].clone();
                        for (int k=0;k<weighted.length;k++) weighted[k]*=fWeights[k];
                    }
                    for (int col=row;col<numPars;col++){
                        double sum=0;
                        for (int k=0;k<weighted.length;k++){
                            if (weighted[k]!=0.0) sum+=weighted[k]*jacobian[col][k];
                        }
                        JtByJmod[row][col]=sum;
                    }
                    double dot=0;
                    for (int k=0;k<weighted.length;k++){
                        if (weighted[k]!=0.0) dot+=weighted[k]*diff[k];
                    }
                    JtByDiff[row]=dot;
                }
            }
        };
    }
    startAndJoin(threads);
    // the matrix is symmetrical - copy the upper triangle to the lower one
    for (int i=1;i<numPars;i++) {
        for (int j=0;j<i;j++) JtByJmod[i][j]=JtByJmod[j][i];
    }
    LMAArrays result = new LMAArrays();
    result.jTByJ=JtByJmod;
    result.jTByDiff=JtByDiff;
    return result;
}
public double [] solveLMA(
LMAArrays lMAArrays,
double lambda,
......@@ -2763,10 +3018,15 @@ d_s2/d_x0= 2*delta_x*delta_y^2/r2^2
String msg="initial Jacobian matrix calculation. Points:"+this.dataValues.length+" Parameters:"+this.currentVector.length;
if (debugLevel>1) System.out.println(msg);
if (this.updateStatus) IJ.showStatus(msg);
System.out.println("*** 1 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-this.startTime),5));
this.currentfX=createFXandJacobian(this.currentVector, true); // is it always true here (this.jacobian==null)
System.out.println("*** 2 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-this.startTime),5));
this.lMAArrays=calculateJacobianArrays(this.currentfX);
System.out.println("*** 3 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-this.startTime),5));
this.currentRMS= calcErrorDiffY(this.currentfX,false);
System.out.println("*** 4 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-this.startTime),5));
this.currentRMSPure=calcErrorDiffY(this.currentfX, true);
System.out.println("*** 5 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-this.startTime),5));
msg=this.currentStrategyStep+": initial RMS="+IJ.d2s(this.currentRMS,8)+" (pure RMS="+IJ.d2s(this.currentRMSPure,8)+")"+
". Calculating next Jacobian. Points:"+this.dataValues.length+" Parameters:"+this.currentVector.length;
if (debugLevel>1) System.out.println(msg);
......@@ -2813,10 +3073,16 @@ d_s2/d_x0= 2*delta_x*delta_y^2/r2^2
// this.savedJacobian=this.jacobian;
this.savedLMAArrays=lMAArrays.clone();
this.jacobian=null; // not needed, just to catch bugs
if (debugLevel>1) System.out.println("*** 6 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-this.startTime),5));
this.nextfX=createFXandJacobian(this.nextVector,true);
if (debugLevel>1) System.out.println("*** 7 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-this.startTime),5));
this.lMAArrays=calculateJacobianArrays(this.nextfX);
if (debugLevel>1) System.out.println("*** 8 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-this.startTime),5));
this.nextRMS= calcErrorDiffY(this.nextfX,false);
if (debugLevel>1) System.out.println("*** 9 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-this.startTime),5));
this.nextRMSPure= calcErrorDiffY(this.nextfX,true);
if (debugLevel>1) System.out.println("*** 10 @ "+ IJ.d2s(0.000000001*(System.nanoTime()-this.startTime),5));
this.lastImprovements[1]=this.lastImprovements[0];
this.lastImprovements[0]=this.currentRMS-this.nextRMS;
String msg="currentRMS="+this.currentRMS+
......@@ -2937,6 +3203,8 @@ public boolean selectLMAParameters(boolean autoSel){
gd.addCheckbox("Show modified parameters", this.showParams);
gd.addCheckbox("Show disabled parameters", this.showDisabledParams);
gd.addCheckbox("Show per-sample correction parameters", this.showCorrectionParams);
gd.addNumericField("Maximal number of threads (0 - old code)", this.threadsMax, 0);
//threadsMax
// gd.addCheckbox("Reset all per-sample corrections to zero", resetCorrections);
......@@ -2969,6 +3237,7 @@ public boolean selectLMAParameters(boolean autoSel){
this.showParams= gd.getNextBoolean();
this.showDisabledParams= gd.getNextBoolean();
this.showCorrectionParams= gd.getNextBoolean();
this.threadsMax= (int) gd.getNextNumber();
// if (!keepCorrectionParameters) fieldFitting.resetSampleCorr();
return true;
}
......@@ -6669,7 +6938,7 @@ public boolean LevenbergMarquardt(
int [] motors, // 3 motor coordinates
double px, // pixel x
double py, // pixel y
double [][] deriv // array of (1..6[][], matching getNumberOfChannels) or null if derivatives are not required
double [][] deriv // array of (1..6][], matching getNumberOfChannels) or null if derivatives are not required
){
// if (sampleIndex==39) {
// System.out.print("?");
......
......@@ -307,6 +307,7 @@ public class LensAdjustment {
public boolean scanTiltEnable=false; //true; // enable scanning tilt
public boolean scanTiltReverse=false; // enable scanning tilt in both directions
public boolean scanMeasureLast=false; // Calculate PSF after last move (to original position)
public boolean scanRunLMA=true; // Calculate model parameters (run LMA) after scanning
public int scanTiltRangeX=14336; // 4 periods
public int scanTiltRangeY=14336; // 4 periods
public int scanTiltStepsX=24;
......@@ -524,6 +525,7 @@ public class LensAdjustment {
boolean scanTiltEnable, //=true; // enable scanning tilt
boolean scanTiltReverse,
boolean scanMeasureLast,
boolean scanRunLMA,
int scanTiltRangeX, //=14336; // 4 periods
int scanTiltRangeY, //=14336; // 4 periods
int scanTiltStepsX, //=24;
......@@ -669,6 +671,7 @@ public class LensAdjustment {
this.scanTiltEnable=scanTiltEnable; //=true; // enable scanning tilt
this.scanTiltReverse=scanTiltReverse;
this.scanMeasureLast=scanMeasureLast;
this.scanRunLMA=scanRunLMA;
this.scanTiltRangeX=scanTiltRangeX; //, //=14336; // 4 periods
this.scanTiltRangeY=scanTiltRangeY; //, //=14336; // 4 periods
this.scanTiltStepsX=scanTiltStepsX; //=24;
......@@ -815,6 +818,7 @@ public class LensAdjustment {
this.scanTiltEnable, // enable scanning tilt
this.scanTiltReverse,
this.scanMeasureLast,
this.scanRunLMA,
this.scanTiltRangeX, // 4 periods
this.scanTiltRangeY, // 4 periods
this.scanTiltStepsX,
......@@ -966,6 +970,7 @@ public class LensAdjustment {
properties.setProperty(prefix+"scanTiltEnable",this.scanTiltEnable+""); // enable scanning tilt
properties.setProperty(prefix+"scanTiltReverse",this.scanTiltReverse+"");
properties.setProperty(prefix+"scanMeasureLast",this.scanMeasureLast+"");
properties.setProperty(prefix+"scanRunLMA",this.scanRunLMA+"");
properties.setProperty(prefix+"scanTiltRangeX",this.scanTiltRangeX+""); // 4 periods
properties.setProperty(prefix+"scanTiltRangeY",this.scanTiltRangeY+""); // 4 periods
properties.setProperty(prefix+"scanTiltStepsX",this.scanTiltStepsX+"");
......@@ -1226,6 +1231,9 @@ public class LensAdjustment {
if (properties.getProperty(prefix+"scanMeasureLast")!=null)
this.scanMeasureLast=Boolean.parseBoolean(properties.getProperty(prefix+"scanMeasureLast"));
if (properties.getProperty(prefix+"scanRunLMA")!=null)
this.scanRunLMA=Boolean.parseBoolean(properties.getProperty(prefix+"scanRunLMA"));
if (properties.getProperty(prefix+"scanTiltRangeX")!=null)
this.scanTiltRangeX=Integer.parseInt(properties.getProperty(prefix+"scanTiltRangeX"));
if (properties.getProperty(prefix+"scanTiltRangeY")!=null)
......@@ -1374,6 +1382,8 @@ public class LensAdjustment {
gd.addCheckbox ("Scan for tilt measurement (approximately preserving center)", this.scanTiltEnable);
gd.addCheckbox ("Scan for tilt measurement in both directions", this.scanTiltReverse);
gd.addCheckbox ("Calculate PSF after returning to the initial position", this.scanMeasureLast);
gd.addCheckbox ("Calculate model parameters after scanning", this.scanRunLMA);
gd.addNumericField("Full range of scanning motors tilting in X-direction", this.scanTiltRangeX, 0,7,"motors steps");
gd.addNumericField("Full range of scanning motors tilting in Y-direction", this.scanTiltRangeY, 0,7,"motors steps");
......@@ -1396,6 +1406,7 @@ public class LensAdjustment {
this.scanTiltEnable= gd.getNextBoolean();
this.scanTiltReverse= gd.getNextBoolean();
this.scanMeasureLast= gd.getNextBoolean();
this.scanRunLMA= gd.getNextBoolean();
this.scanTiltRangeX= (int) gd.getNextNumber();
this.scanTiltRangeY= (int) gd.getNextNumber();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment