package com.elphel.imagej.tileprocessor;

import java.util.ArrayList;
import java.util.Arrays;

import com.elphel.imagej.cameras.CLTParameters;

import Jama.Matrix;

/**
 * Classic-LMA structured global refinement path.
 *
 * <p>Phase-1 keeps numeric behavior stable by delegating the actual solve to
 * {@link IntersceneGlobalRefine}, while this class owns the classic LMA structure
 * ({@code prepareLMA/getFxDerivs/getYminusFxWeighted/lmaStep/runLma}) for incremental migration.
 */
public class IntersceneGlobalLmaRefine {

	/**
	 * Order in which per-scene pose parameter indices are visited when selecting active parameters.
	 * Holds a private copy of {@code ErsCorrection.DP_XYZATR}, so changes to either array do not
	 * affect the other.
	 */
	private static final int[] SCENE_POSE_PARAM_ORDER = ErsCorrection.DP_XYZATR.clone();

	/**
	 * Mutable working set of the classic LMA scaffold: the flattened parameter vector with its
	 * scene/parameter bookkeeping, plus the linearization cached between {@code lmaStep()} calls.
	 */
	private static final class ClassicState {
		/** Pose parameter indices selected for adjustment. */
		int[] activePoseParams = {};
		/** Scene index of each entry of the parameter vector. */
		int[] vectorSceneIndices = {};
		/** Pose parameter index of each entry of the parameter vector. */
		int[] vectorParamIndices = {};
		/** Current parameter values. */
		double[] parameterVector = {};
		/** Parameter values captured when the vector was built. */
		double[] parameterInitial = {};
		/** Per-parameter target values the residual is measured against. */
		double[] parameterPull = {};
		/** Target ("y") sample values. */
		double[] yVector = {};
		/** Per-sample weights. */
		double[] weights = {};

		/** Derivative matrix of the last linearization, indexed [parameter][sample]. */
		double[][] lastJt;
		/** Weighted (y - fx) of the last linearization. */
		double[] lastYminusFx;
		/** RMS pair of the last linearization. */
		double[] lastRms;
		/** RMS pair recorded when the linearization was first computed. */
		double[] initialRms;

		/** Drops all cached linearization data so that the next step recomputes it. */
		void clearLinearization() {
			initialRms = null;
			lastRms = null;
			lastYminusFx = null;
			lastJt = null;
		}
	}

	// Solver inputs, stored by reference at construction. All of them are forwarded unchanged to
	// IntersceneGlobalRefine.refineAllToReference() by runLma().
	private final CLTParameters cltParameters;
	private final QuadCLT[] quadCLTs;
	private final QuadCLT centerCLT;
	// Index of the scene that is skipped when the parameter vector is built.
	private final int centerIndex;
	// Inclusive range of scene indices scanned by buildStateVector().
	private final int earliestScene;
	private final int lastScene;
	// Per-scene poses: [scene][0] holds at least 3 values, [scene][1] holds at least 3 values
	// (see isPoseValid()); presumably {x,y,z} and {azimuth,tilt,roll} - confirm against callers.
	private final double[][][] scenesXyzatr;
	// Optional per-scene target poses with the same layout; may be null.
	private final double[][][] scenesXyzatrPull;
	// Per-parameter selection mask indexed by pose parameter index; null selects every parameter.
	private final boolean[] paramSelect;
	// Per-parameter weights indexed by pose parameter index; used to derive raw sample weights.
	private final double[] paramRegweights;
	private final double[] paramLpf;
	// Not read by this class directly; only passed through to the delegate solver.
	private final double[] centerDisparity;
	private final boolean[] reliableRef;
	private final boolean disableErs;
	private final double mbMaxGain;
	private final IntersceneGlobalRefine.Options options;
	// Mutable scaffold state (parameter vector, weights, cached linearization).
	private final ClassicState classicState = new ClassicState();

	/**
	 * Captures all solver inputs. Arrays are stored by reference; no defensive copies are made.
	 * Instances are created only through {@link #refineAllToReference}.
	 */
	private IntersceneGlobalLmaRefine(
			final CLTParameters cltParameters,
			final QuadCLT[] quadCLTs,
			final QuadCLT centerCLT,
			final int centerIndex,
			final int earliestScene,
			final int lastScene,
			final double[][][] scenesXyzatr,
			final double[][][] scenesXyzatrPull,
			final boolean[] paramSelect,
			final double[] paramRegweights,
			final double[] paramLpf,
			final double[] centerDisparity,
			final boolean[] reliableRef,
			final boolean disableErs,
			final double mbMaxGain,
			final IntersceneGlobalRefine.Options options) {
		// Scenes and the window of scene indices to process.
		this.quadCLTs = quadCLTs;
		this.centerCLT = centerCLT;
		this.centerIndex = centerIndex;
		this.earliestScene = earliestScene;
		this.lastScene = lastScene;

		// Current poses and the optional poses to pull towards.
		this.scenesXyzatr = scenesXyzatr;
		this.scenesXyzatrPull = scenesXyzatrPull;

		// Per-parameter selection and weighting.
		this.paramSelect = paramSelect;
		this.paramRegweights = paramRegweights;
		this.paramLpf = paramLpf;

		// Reference-scene data handed through to the delegate solver.
		this.centerDisparity = centerDisparity;
		this.reliableRef = reliableRef;

		// Solver configuration.
		this.cltParameters = cltParameters;
		this.disableErs = disableErs;
		this.mbMaxGain = mbMaxGain;
		this.options = options;
	}

	/**
	 * Entry point of the classic-LMA structured refinement: builds a solver instance from the
	 * supplied inputs, prepares the classic state vector and runs the refinement.
	 *
	 * <p>The numeric solve itself is delegated to
	 * {@code IntersceneGlobalRefine.refineAllToReference()} with the same arguments.
	 *
	 * @param centerIndex   index of the scene excluded from the unknowns
	 * @param earliestScene first scene index (inclusive) considered
	 * @param lastScene     last scene index (inclusive) considered
	 * @param debugLevel    verbosity; status lines are printed when above -4
	 * @return the result produced by the delegate solver
	 */
	public static IntersceneGlobalRefine.Result refineAllToReference(
			final CLTParameters cltParameters,
			final QuadCLT[] quadCLTs,
			final QuadCLT centerCLT,
			final int centerIndex,
			final int earliestScene,
			final int lastScene,
			final double[][][] scenesXyzatr,
			final double[][][] scenesXyzatrPull,
			final boolean[] paramSelect,
			final double[] paramRegweights,
			final double[] paramLpf,
			final double[] centerDisparity,
			final boolean[] reliableRef,
			final boolean disableErs,
			final double mbMaxGain,
			final IntersceneGlobalRefine.Options options,
			final int debugLevel) {
		final IntersceneGlobalLmaRefine refiner = new IntersceneGlobalLmaRefine(
				cltParameters, quadCLTs, centerCLT,
				centerIndex, earliestScene, lastScene,
				scenesXyzatr, scenesXyzatrPull,
				paramSelect, paramRegweights, paramLpf,
				centerDisparity, reliableRef,
				disableErs, mbMaxGain, options);
		// Build the classic state first, then solve.
		refiner.prepareLMA(debugLevel);
		final IntersceneGlobalRefine.Result outcome = refiner.runLma(debugLevel);
		return outcome;
	}

	/**
	 * Checks that a scene pose has at least two non-null rows with at least three values each.
	 *
	 * @param scenePose pose to check, may be null
	 * @return true when rows 0 and 1 can both be indexed with 0..2
	 */
	private static boolean isPoseValid(final double[][] scenePose) {
		if ((scenePose == null) || (scenePose.length < 2)) {
			return false;
		}
		for (int row = 0; row < 2; row++) {
			final double[] component = scenePose[row];
			if ((component == null) || (component.length < 3)) {
				return false;
			}
		}
		return true;
	}

	/**
	 * Reads one pose component selected by a pose parameter index.
	 *
	 * @param scenePose pose as two rows of at least three values each
	 * @param dpIndex   one of the {@code ErsCorrection.DP_DS*} pose parameter indices
	 * @return the selected component, or NaN when the pose is invalid or the index is not a
	 *         recognized pose parameter
	 */
	private static double getScenePoseParameter(
			final double[][] scenePose,
			final int dpIndex) {
		if (isPoseValid(scenePose)) {
			final double[] firstRow = scenePose[0];
			final double[] secondRow = scenePose[1];
			if (dpIndex == ErsCorrection.DP_DSX) {
				return firstRow[0];
			} else if (dpIndex == ErsCorrection.DP_DSY) {
				return firstRow[1];
			} else if (dpIndex == ErsCorrection.DP_DSZ) {
				return firstRow[2];
			} else if (dpIndex == ErsCorrection.DP_DSAZ) {
				return secondRow[0];
			} else if (dpIndex == ErsCorrection.DP_DSTL) {
				return secondRow[1];
			} else if (dpIndex == ErsCorrection.DP_DSRL) {
				return secondRow[2];
			}
		}
		return Double.NaN;
	}

	/**
	 * Writes one pose component selected by a pose parameter index. Does nothing when the pose is
	 * invalid or the index is not a recognized pose parameter.
	 *
	 * @param scenePose pose as two rows of at least three values each; modified in place
	 * @param dpIndex   one of the {@code ErsCorrection.DP_DS*} pose parameter indices
	 * @param value     value to store
	 */
	@SuppressWarnings("unused")
	private static void setScenePoseParameter(
			final double[][] scenePose,
			final int dpIndex,
			final double value) {
		if (!isPoseValid(scenePose)) {
			return;
		}
		final double[] firstRow = scenePose[0];
		final double[] secondRow = scenePose[1];
		if (dpIndex == ErsCorrection.DP_DSX) {
			firstRow[0] = value;
		} else if (dpIndex == ErsCorrection.DP_DSY) {
			firstRow[1] = value;
		} else if (dpIndex == ErsCorrection.DP_DSZ) {
			firstRow[2] = value;
		} else if (dpIndex == ErsCorrection.DP_DSAZ) {
			secondRow[0] = value;
		} else if (dpIndex == ErsCorrection.DP_DSTL) {
			secondRow[1] = value;
		} else if (dpIndex == ErsCorrection.DP_DSRL) {
			secondRow[2] = value;
		}
	}

	/**
	 * Replaces NaN and infinite values with zero.
	 *
	 * @param value any double
	 * @return {@code value} when it is finite, otherwise 0.0
	 */
	private static double normalizeValue(final double value) {
		if (Double.isNaN(value) || Double.isInfinite(value)) {
			return 0.0;
		}
		return value;
	}

	/**
	 * Bounds-checked array read.
	 *
	 * @param values array to read from, may be null
	 * @param index  element index
	 * @return {@code values[index]}, or NaN when the array is null or the index is out of range
	 */
	private static double getArrayValue(
			final double[] values,
			final int index) {
		final boolean readable = (values != null) && (index >= 0) && (index < values.length);
		return readable ? values[index] : Double.NaN;
	}

	/**
	 * Lists the pose parameter indices to adjust, in {@code SCENE_POSE_PARAM_ORDER} order.
	 * A null {@code paramSelect} selects every parameter; otherwise a parameter is selected only
	 * when its index is inside the mask and the mask entry is true.
	 *
	 * @return selected pose parameter indices (possibly empty, never null)
	 */
	private int[] buildActivePoseParameters() {
		final int[] picked = new int[SCENE_POSE_PARAM_ORDER.length];
		int count = 0;
		for (int k = 0; k < SCENE_POSE_PARAM_ORDER.length; k++) {
			final int dpIndex = SCENE_POSE_PARAM_ORDER[k];
			if (paramSelect != null) {
				final boolean inMask = (dpIndex >= 0) && (dpIndex < paramSelect.length);
				if (!inMask || !paramSelect[dpIndex]) {
					continue;
				}
			}
			picked[count++] = dpIndex;
		}
		// Trim the buffer to the number of parameters actually selected.
		return Arrays.copyOf(picked, count);
	}

	/**
	 * Scales weights so that they sum to one. Entries that are non-finite or not positive count
	 * as zero. When no usable weight exists, every entry gets the same weight {@code 1/length}.
	 *
	 * @param rawWeights weights to normalize, may be null or empty; not modified
	 * @return a new array with the normalized weights (empty for null or empty input)
	 */
	private static double[] normalizeWeights(final double[] rawWeights) {
		if ((rawWeights == null) || (rawWeights.length == 0)) {
			return new double[0];
		}
		final int count = rawWeights.length;
		final double[] normalized = new double[count];
		double total = 0.0;
		// First pass: keep usable weights, zero the rest, accumulate the sum.
		for (int i = 0; i < count; i++) {
			final double w = rawWeights[i];
			if (Double.isFinite(w) && (w > 0.0)) {
				normalized[i] = w;
				total += w;
			} else {
				normalized[i] = 0.0;
			}
		}
		if (total <= 0.0) {
			Arrays.fill(normalized, 1.0 / count);
			return normalized;
		}
		// Second pass: scale the usable weights; zeroed entries stay zero.
		for (int i = 0; i < count; i++) {
			if (normalized[i] != 0.0) {
				normalized[i] /= total;
			}
		}
		return normalized;
	}

	/**
	 * Rebuilds the classic state from the scene poses: one unknown per (scene, active pose
	 * parameter) pair for every scene in {@code [earliestScene, lastScene]} that is not the
	 * center scene, is inside {@code scenesXyzatr} and has a valid pose.
	 *
	 * <p>For each unknown the current value (non-finite replaced by 0) becomes both the parameter
	 * and its initial value; the pull value comes from {@code scenesXyzatrPull} when that scene
	 * has a valid pull pose, otherwise it equals the current value. The raw weight is the
	 * regularization weight if positive, else the LPF weight if positive, else 1; weights are
	 * then normalized. A null {@code scenesXyzatr} yields an empty state. Cached linearization
	 * is always discarded.
	 */
	private void buildStateVector() {
		final int[] activeParams = buildActivePoseParameters();
		classicState.activePoseParams = activeParams;

		// Pass 1: find the scenes that contribute unknowns (none when there are no poses).
		final ArrayList<Integer> usableScenes = new ArrayList<Integer>();
		if (scenesXyzatr != null) {
			for (int scene = earliestScene; scene <= lastScene; scene++) {
				final boolean inRange = (scene >= 0) && (scene < scenesXyzatr.length);
				if ((scene != centerIndex) && inRange && isPoseValid(scenesXyzatr[scene])) {
					usableScenes.add(scene);
				}
			}
		}

		final int total = usableScenes.size() * activeParams.length;
		final int[] sceneOf = new int[total];
		final int[] paramOf = new int[total];
		final double[] values = new double[total];
		final double[] initial = new double[total];
		final double[] pulls = new double[total];
		final double[] rawWeights = new double[total];

		// Pass 2: fill one slot per (scene, parameter) pair, scene-major.
		int slot = 0;
		for (final int scene : usableScenes) {
			final double[][] pose = scenesXyzatr[scene];
			// usable scenes are never negative, so only the upper bound needs checking here
			final boolean hasPull = (scenesXyzatrPull != null) &&
					(scene < scenesXyzatrPull.length) &&
					isPoseValid(scenesXyzatrPull[scene]);
			for (final int dpIndex : activeParams) {
				final double current = normalizeValue(getScenePoseParameter(pose, dpIndex));
				final double pull = hasPull
						? normalizeValue(getScenePoseParameter(scenesXyzatrPull[scene], dpIndex))
						: current;

				final double regWeight = getArrayValue(paramRegweights, dpIndex);
				final double lpfWeight = getArrayValue(paramLpf, dpIndex);
				final double rawWeight;
				if (regWeight > 0.0) {
					rawWeight = regWeight;
				} else if (lpfWeight > 0.0) {
					rawWeight = lpfWeight;
				} else {
					rawWeight = 1.0;
				}

				sceneOf[slot] = scene;
				paramOf[slot] = dpIndex;
				values[slot] = current;
				initial[slot] = current;
				pulls[slot] = pull;
				rawWeights[slot] = rawWeight;
				slot++;
			}
		}

		// Publish the freshly built arrays; the target vector is all zeros.
		classicState.vectorSceneIndices = sceneOf;
		classicState.vectorParamIndices = paramOf;
		classicState.parameterVector = values;
		classicState.parameterInitial = initial;
		classicState.parameterPull = pulls;
		classicState.yVector = new double[total];
		classicState.weights = normalizeWeights(rawWeights);
		classicState.clearLinearization();
	}

	/**
	 * Reloads the parameter vector from {@code scenesXyzatr}. Entries whose scene index is outside
	 * the pose array keep their previous value; non-finite pose values are stored as 0.
	 * Does nothing when there are no poses.
	 */
	private void captureCurrentVectorFromScenes() {
		if (scenesXyzatr == null) {
			return;
		}
		final int sceneCount = scenesXyzatr.length;
		for (int slot = 0; slot < classicState.parameterVector.length; slot++) {
			final int scene = classicState.vectorSceneIndices[slot];
			final int dpIndex = classicState.vectorParamIndices[slot];
			if ((scene < 0) || (scene >= sceneCount)) {
				continue;
			}
			final double fromScene = getScenePoseParameter(scenesXyzatr[scene], dpIndex);
			classicState.parameterVector[slot] = normalizeValue(fromScene);
		}
	}

	/**
	 * Builds the classic state vector and, when {@code debugLevel} is above -4, prints a one-line
	 * summary of the problem size.
	 *
	 * @param debugLevel verbosity level
	 */
	private void prepareLMA(final int debugLevel) {
		buildStateVector();
		if (debugLevel <= -4) {
			return;
		}
		final StringBuilder message = new StringBuilder();
		message.append("IntersceneGlobalLmaRefine: prepareLMA() classic scaffold active; ");
		message.append("unknowns=").append(classicState.parameterVector.length);
		message.append(", activePoseParams=").append(Arrays.toString(classicState.activePoseParams));
		message.append(", center=").append(centerIndex);
		message.append(", range=[").append(earliestScene).append(",").append(lastScene).append("]");
		System.out.println(message.toString());
	}

	/**
	 * Evaluates the placeholder model {@code fx[i] = vector[i] - parameterPull[i]} and, when
	 * requested, its derivatives (an identity matrix).
	 *
	 * @param vector     parameter values to evaluate at
	 * @param jt         optional output, indexed [parameter][sample]; must have at least
	 *                   {@code vector.length} rows. Rows that are null or of the wrong length are
	 *                   replaced. Pass null to skip derivatives.
	 * @param debugLevel verbosity; a status line is printed when above 2 and the vector is not empty
	 * @return model values, one per vector entry
	 */
	private double[] getFxDerivs(
			final double[] vector,
			final double[][] jt,
			final int debugLevel) {
		final int count = vector.length;
		final boolean withDerivs = (jt != null);
		if (withDerivs) {
			// Reset every row to zero, reusing rows that already have the right size.
			for (int row = 0; row < count; row++) {
				final double[] existing = jt[row];
				if ((existing != null) && (existing.length == count)) {
					Arrays.fill(existing, 0.0);
				} else {
					jt[row] = new double[count]; // freshly allocated rows are already zero
				}
			}
		}
		final double[] residual = new double[count];
		for (int i = 0; i < count; i++) {
			residual[i] = vector[i] - classicState.parameterPull[i];
			if (withDerivs) {
				jt[i][i] = 1.0;
			}
		}
		if ((count > 0) && (debugLevel > 2)) {
			System.out.println("IntersceneGlobalLmaRefine: getFxDerivs() placeholder residual blocks active, samples=" + count);
		}
		return residual;
	}

	/**
	 * Computes the weighted difference {@code (y - fx) * weight} per sample and the square root
	 * of the weighted sum of squared differences.
	 *
	 * @param fx     model values, one per sample
	 * @param rms    optional output of at least two elements; both receive the same RMS value,
	 *               or NaN when the method gives up because of a NaN sample
	 * @param noNaNs when true, a NaN weighted difference aborts the computation; when false such
	 *               samples are treated as zero
	 * @return weighted differences, or null when {@code noNaNs} is set and a NaN was found
	 */
	private double[] getYminusFxWeighted(
			final double[] fx,
			final double[] rms,
			final boolean noNaNs) {
		final int count = fx.length;
		final boolean reportRms = (rms != null) && (rms.length >= 2);
		final double[] result = new double[count];
		double sumSquares = 0.0;
		for (int i = 0; i < count; i++) {
			final double diff = classicState.yVector[i] - fx[i];
			final double weightedDiff = diff * classicState.weights[i];
			if (!Double.isNaN(weightedDiff)) {
				result[i] = weightedDiff;
				sumSquares += diff * weightedDiff;
				continue;
			}
			if (noNaNs) {
				if (reportRms) {
					rms[0] = Double.NaN;
					rms[1] = Double.NaN;
				}
				return null;
			}
			// Tolerated NaN sample: it contributes nothing and result[i] stays 0.0.
		}
		if (reportRms) {
			final double root = Math.sqrt(sumSquares);
			rms[0] = root;
			rms[1] = root;
		}
		return result;
	}

	/**
	 * Builds the symmetric matrix {@code Jt * W * J} with every diagonal element multiplied by
	 * {@code (1 + lambda)}.
	 *
	 * @param lambda damping factor applied to the diagonal
	 * @param jt     derivative matrix indexed [parameter][sample]; each row must cover all weights
	 * @return a new square matrix with one row and column per parameter
	 */
	private double[][] getWJtJlambda(
			final double lambda,
			final double[][] jt) {
		final int size = jt.length;
		final double[] w = classicState.weights;
		final double[][] damped = new double[size][size];
		for (int row = 0; row < size; row++) {
			final double[] jtRow = jt[row];
			// Only the upper triangle is computed; off-diagonal values are mirrored.
			for (int col = row; col < size; col++) {
				final double[] jtCol = jt[col];
				double acc = 0.0;
				for (int k = 0; k < w.length; k++) {
					acc += w[k] * jtRow[k] * jtCol[k];
				}
				if (col == row) {
					damped[row][row] = acc + acc * lambda;
				} else {
					damped[row][col] = acc;
					damped[col][row] = acc;
				}
			}
		}
		return damped;
	}

	/**
	 * Finds the largest absolute value in an array. NaN elements are ignored.
	 *
	 * @param data values to scan; must not be null
	 * @return the largest magnitude, or 0.0 for an empty array
	 */
	private static double getMaxAbs(final double[] data) {
		double largest = 0.0;
		for (final double value : data) {
			final double magnitude = Math.abs(value);
			if (magnitude > largest) {
				largest = magnitude;
			}
		}
		return largest;
	}

	/**
	 * Performs one damped least-squares step on the classic state.
	 *
	 * <p>The cached linearization (derivative matrix, weighted residual, RMS) is computed on the
	 * first call and recomputed whenever any part of it is missing or the derivative matrix had
	 * to be reallocated. The step is accepted only when it lowers the RMS; otherwise the previous
	 * parameter vector and its linearization are restored.
	 *
	 * @param lambda      damping factor applied to the diagonal of the normal matrix
	 * @param rmsDiffStop relative RMS improvement below which an accepted step counts as converged
	 * @param deltaStop   largest absolute parameter change at or below which an accepted step
	 *                    counts as converged
	 * @param debugLevel  verbosity passed on to the model evaluation
	 * @return {@code {improved, done}}: {@code {true, true}} for an empty problem,
	 *         {@code {false, true}} when the normal matrix cannot be inverted,
	 *         {@code {false, false}} when the step was rejected; {@code null} when the initial
	 *         residual contains NaN (the cached linearization is cleared in that case)
	 */
	@SuppressWarnings("unused")
	private boolean[] lmaStep(
			final double lambda,
			final double rmsDiffStop,
			final double deltaStop,
			final int debugLevel) {
		final int n = classicState.parameterVector.length;
		if (n == 0) {
			return new boolean[] {true, true};
		}
		// A missing residual or RMS means there is no usable linearization, e.g. after a
		// previous call gave up on a NaN residual.
		boolean needLinearization =
				(classicState.lastRms == null) || (classicState.lastYminusFx == null);
		if ((classicState.lastJt == null) || (classicState.lastJt.length != n)) {
			classicState.lastJt = new double[n][];
			// The new matrix has no rows yet, so it must be filled before it is used.
			needLinearization = true;
		}
		if (needLinearization) {
			final double[] rms0 = new double[2];
			final double[] fx0 = getFxDerivs(
					classicState.parameterVector,
					classicState.lastJt,
					debugLevel);
			final double[] yMinusFx0 = getYminusFxWeighted(
					fx0,
					rms0,
					true);
			if (yMinusFx0 == null) {
				// Leave no half-initialized state behind: the next call starts from scratch
				// instead of dereferencing a null residual.
				classicState.clearLinearization();
				return null;
			}
			classicState.lastYminusFx = yMinusFx0;
			classicState.lastRms = rms0;
			classicState.initialRms = rms0.clone();
		}

		final Matrix yMinusFxWeighted = new Matrix(classicState.lastYminusFx, classicState.lastYminusFx.length);
		final Matrix wjtjLambda = new Matrix(getWJtJlambda(lambda, classicState.lastJt));
		final Matrix jty = (new Matrix(classicState.lastJt)).times(yMinusFxWeighted);
		final Matrix deltaVector;
		try {
			deltaVector = wjtjLambda.inverse().times(jty);
		} catch (RuntimeException ex) {
			// Singular normal matrix: no step possible, tell the caller to stop.
			return new boolean[] {false, true};
		}

		final double[] delta = deltaVector.getColumnPackedCopy();
		final double maxDelta = getMaxAbs(delta);
		final double[] oldVector = classicState.parameterVector.clone();
		final double[] oldRms = classicState.lastRms.clone();
		final double[] newVector = oldVector.clone();
		for (int i = 0; i < n; i++) {
			newVector[i] += delta[i];
		}

		// Evaluating at the candidate overwrites the cached derivative matrix in place.
		final double[] fx = getFxDerivs(
				newVector,
				classicState.lastJt,
				debugLevel);
		final double[] rms = new double[2];
		final double[] yMinusFxNew = getYminusFxWeighted(
				fx,
				rms,
				true);
		if ((yMinusFxNew != null) && (rms[0] < oldRms[0])) {
			classicState.parameterVector = newVector;
			classicState.lastRms = rms;
			classicState.lastYminusFx = yMinusFxNew;
			final boolean convergedByRms = rms[0] >= (oldRms[0] * (1.0 - rmsDiffStop));
			final boolean convergedByDelta = maxDelta <= deltaStop;
			return new boolean[] {true, convergedByRms || convergedByDelta};
		}

		// Step rejected (worse RMS or NaN): re-linearize at the previous vector so the cached
		// derivative matrix, residual and RMS match it again.
		final double[] fxRestore = getFxDerivs(
				oldVector,
				classicState.lastJt,
				debugLevel);
		classicState.lastYminusFx = getYminusFxWeighted(
				fxRestore,
				classicState.lastRms,
				true);
		classicState.parameterVector = oldVector;
		return new boolean[] {false, false};
	}

	/**
	 * Runs the refinement. In this phase the numeric solve is delegated to
	 * {@code IntersceneGlobalRefine.refineAllToReference()} with the inputs captured at
	 * construction; afterwards the classic parameter vector is reloaded from the scene poses.
	 *
	 * @param debugLevel verbosity; a status line is printed when above -4
	 * @return the result returned by the delegate solver
	 */
	private IntersceneGlobalRefine.Result runLma(final int debugLevel) {
		final boolean verbose = debugLevel > -4;
		if (verbose) {
			System.out.println(
					"IntersceneGlobalLmaRefine: runLma() phase-1 delegating numeric solve to IntersceneGlobalRefine");
		}
		final IntersceneGlobalRefine.Result delegated = IntersceneGlobalRefine.refineAllToReference(
				cltParameters, quadCLTs, centerCLT,
				centerIndex, earliestScene, lastScene,
				scenesXyzatr, scenesXyzatrPull,
				paramSelect, paramRegweights, paramLpf,
				centerDisparity, reliableRef,
				disableErs, mbMaxGain, options,
				debugLevel);
		// Refresh the classic vector from the pose array after the delegate has run.
		captureCurrentVectorFromScenes();
		return delegated;
	}
}
