package com.elphel.imagej.tensorflow;
/**
 * Copyright (C) 2018 Elphel, Inc.
 * SPDX-License-Identifier: GPL-3.0-or-later
 */

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.util.ArrayList;

import org.apache.ant.compress.taskdefs.Unzip;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

import com.elphel.imagej.common.ShowDoubleFloatArrays;
import com.elphel.imagej.tileprocessor.ImageDtt;

import ij.IJ;
import ij.ImagePlus;
import ij.ImageStack;
import ij.WindowManager;

/**

Note 0: articles & examples:

- https://divis.io/2017/11/enterprise-tensorflow-1/
- https://divis.io/2018/01/enterprise-tensorflow-code-examples/
- https://divis.io/2018/01/enterprise-tensorflow-2-saving-a-trained-model/
- https://divis.io/2018/01/enterprise-tensorflow-3-loading-a-savedmodel-in-java/
- https://divis.io/2018/01/enterprise-tensorflow-4-executing-a-tensorflow-session-in-java/
- https://www.programcreek.com/java-api-examples/?api=org.tensorflow.SavedModelBundle
- https://github.com/imagej/imagej-tensorflow
- simple: http://www.riptutorial.com/tensorflow/example/32154/load-and-use-the-model-in-java-

Note 1: How to feed:
 a. https://divis.io/2018/01/enterprise-tensorflow-4-executing-a-tensorflow-session-in-java/
 b. https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/examples/LabelImage.java



Note 2: https://divis.io/2018/01/enterprise-tensorflow-4-executing-a-tensorflow-session-in-java/

...
Two types of objects will need manual closing for proper resource handling:
Sessions and Tensors
...



*/

public class TensorflowInferModel
{
	
    final static String TRAINED_MODEL_URL = "https://community.elphel.com/files/quad-stereo/ml/trained_model_v1.0.zip";
	
    final static String TRAINED_MODEL = "trained_model"; // /home/oleg/GIT/python3-imagej-tiff/data_sets/tf_data_5x5_main_13_heur/exportdir";
    final static String SERVING = "serve";
    
    final int         tilesX, tilesY, num_tiles, num_layers;
    final int         corr_side, corr_side2;
//    final long []     shape_corr2d;
//    final long []     shape_target_disparity;
//    final long []     shape_tiles_stage1;
//    final long []     shape_tiles_stage2;
    final FloatBuffer fb_corr2d;
    final FloatBuffer fb_target_disparity;
    final IntBuffer   fb_tiles_stage1;
    final IntBuffer   fb_tiles_stage2;
    final FloatBuffer fb_predicted;

/*
long[] shape = new long[] {batch, imageSize};
 */
//    public
    final SavedModelBundle bundle;

    // utils: download url
    private static Path download(String sourceURL, String targetDirectory) throws IOException
    {
        URL url = new URL(sourceURL);
        String fileName = sourceURL.substring(sourceURL.lastIndexOf('/') + 1, sourceURL.length());
        Path targetPath = new File(targetDirectory + File.separator + fileName).toPath();
        Files.copy(url.openStream(), targetPath, StandardCopyOption.REPLACE_EXISTING);

        return targetPath;
    }
    
    // utils: unpack zip to dir
    private static boolean download_and_unpack(String sourceURL, String targetDirectory) {
    	
    	Path zipped_model = null;
    	
    	System.out.println("Downloading "+sourceURL+". Please, wait...");
    	try {
			zipped_model = download(sourceURL, targetDirectory);
    	} catch (IOException e){
    		e.printStackTrace();
    		return false;
    	}
    	
		Unzip unzipper = new Unzip();
		unzipper.setSrc(zipped_model.toFile());
		unzipper.setDest(new File(targetDirectory));
		unzipper.execute();
		
		System.out.println(unzipper.getLocation());
    	
    	return true;
    }
    
    public TensorflowInferModel(int tilesX, int tilesY, int corr_side, int num_layers)
    {
    	this.tilesX = tilesX;
    	this.tilesY = tilesY;
    	this.num_tiles = tilesX*tilesY;
    	this.num_layers = num_layers;
    	this.corr_side = corr_side;
    	this.corr_side2 = corr_side * corr_side;
    	// allocate buffers to be used for tensors
    	this.fb_corr2d =           FloatBuffer.allocate(num_tiles * corr_side2*num_layers);
    	this.fb_target_disparity = FloatBuffer.allocate(num_tiles );
    	this.fb_tiles_stage1 =     IntBuffer.allocate(num_tiles );
    	this.fb_tiles_stage2 =     IntBuffer.allocate(num_tiles );
    	this.fb_predicted =        FloatBuffer.allocate(num_tiles );

    	//String resourceDir = System.getProperty("user.dir")+"/src/main/resources";
    	// ./target/classes/
    	String resourceDir = getClass().getClassLoader().getResource("").getFile();
    	String abs_model_path = null;
    	
    	try {
    		abs_model_path = getClass().getClassLoader().getResource(TRAINED_MODEL).getFile();
    	} catch (java.lang.NullPointerException e) {
    		//e.printStackTrace();
    		download_and_unpack(TRAINED_MODEL_URL, resourceDir);
    		// re-read
    		abs_model_path = getClass().getClassLoader().getResource(TRAINED_MODEL).getFile();
    		System.out.println("New downloaded path: "+abs_model_path);
    	}
    	
        System.out.println("TensorflowInferModel model path: "+abs_model_path);
        
        // this will load graph/data and open a session that does not need to be closed until the program is closed
////        bundle = null;
        bundle = SavedModelBundle.load(abs_model_path, SERVING);

//    	Operation opr = bundle.graph().operation("rv_stage1_out");
    }


	public void run_stage1(
			 FloatBuffer fb_corr2d,
			 FloatBuffer fb_target_disparity,
			 IntBuffer   fb_tiles_stage1)
	{
	    int ntiles = fb_tiles_stage1.limit(); // actual number of entries
	    long []     shape_corr2d =           new long [] {ntiles, corr_side2*num_layers};
	    long []     shape_target_disparity = new long [] {ntiles, 1};
	    long []     shape_tiles_stage1 =     new long [] {ntiles};

		final Tensor<Float>   t_corr2d =         Tensor.create(shape_corr2d,           fb_corr2d);
		final Tensor<Float>   target_disparity = Tensor.create(shape_target_disparity, fb_target_disparity);
		final Tensor<Integer> t_tiles_stage1 =   Tensor.create(shape_tiles_stage1,fb_tiles_stage1);

		// this run does not output any data, but maybe it should still be captured and disposed of?
		final Tensor<?> t_result_stage1 = bundle.session().runner()
        .feed("ph_corr2d",           t_corr2d)
        .feed("ph_target_disparity", target_disparity)
        .feed("ph_ntile",            t_tiles_stage1)
        .fetch("Disparity_net/stage1done:0")
        .run()
        .get(0);

		t_result_stage1.  close();
		t_corr2d.         close();
    	target_disparity. close();
    	t_tiles_stage1.   close();
	}

	public void run_stage2(
			 IntBuffer   fb_tiles_stage2,
			 FloatBuffer fb_disparity) {
	    int ntiles = fb_tiles_stage2.limit(); // actual number of entries
	    long []     shape_tiles_stage2 =         new long [] {ntiles};
		final Tensor<Integer> t_tiles_stage2 =   Tensor.create(shape_tiles_stage2, fb_tiles_stage2);
    	final Tensor<?> t_result_disparity = bundle.session().runner()
                .feed("ph_ntile_out",t_tiles_stage2)
                .fetch("Disparity_net/stage2_out_sparse:0")
                .run()
                .get(0);
    	fb_disparity.rewind();
    	t_result_disparity.writeTo(fb_disparity);

    	t_result_disparity. close();
    	t_tiles_stage2.     close();
	}


    //**************************
    // helper class to prepare TileProcessor task
	class TfInTile{
		float [] corr2d;
		float target_disparity;
		float gt_disparity;
		float gt_strength;
		int   tile;
		public void setFromCorr2d(
				float [][] data,
				int  tileX,
				int  tileY,
				int  corr_side,
				int  tilesX
				) {
			int corr_side2 = corr_side*corr_side;
			int layers = data.length -1;
			this.corr2d = new float [layers * corr_side2];
			float [] other = new float [corr_side2];

			int width = tilesX * corr_side;
			int tl = tileY * corr_side * width + tileX * corr_side; // index to the
			for (int nl = 0; nl <= layers; nl++) {
				if (nl < layers) {
					for (int row = 0; row < corr_side; row++) {
						System.arraycopy(data[nl], tl+width*row, corr2d, nl*corr_side2 + row*corr_side, corr_side);
					}
				} else {
					for (int row = 0; row < corr_side; row++) {
						System.arraycopy(data[layers], tl+width*row, other, row*corr_side, corr_side);
					}
					this.target_disparity = other[ImageDtt.ML_OTHER_TARGET];
					this.gt_disparity =     other[ImageDtt.ML_OTHER_GTRUTH];
					this.gt_strength =      other[ImageDtt.ML_OTHER_GTRUTH_STRENGTH];
				}
			}
			if (Float.isNaN(this.target_disparity)) {
				corr2d = new float[corr2d.length]; // zero them all
			}
			this.tile = tileY*tilesX + tileX;

		}
//System.arraycopy(sym_conv, 0, tile_in, 0, n2*n2)

	}
	public int test_tensorflow(
			boolean keep_empty) {

		int dbgX = 162;
		int dbgY = 121;
//		int dbgT = tilesX ^ dbgY + dbgX;

		int corr_side = 9;
		String [] slices = {"hor-pairs","vert-pairs","diagm-pair","vert-pairs","other"};
    	ImagePlus imp_src = WindowManager.getCurrentImage();
    	if (imp_src==null){
    		IJ.showMessage("Error","2D Correlation image stack required");
    		return -1;
    	}
    	ImageStack corr_stack = imp_src.getStack();
    	String [] labels =  corr_stack.getSliceLabels();
//    	for (int ii = 0; ii < labels.length; ii++) {
//    		System.out.println(ii+": "+labels[ii]);
//    	}
    	int tilesX= corr_stack.getWidth()  /  corr_side;
    	int tilesY= corr_stack.getHeight() /  corr_side;

    	float [][] corr_data = new float [slices.length][];
    	for (int nslice = 0; nslice < slices.length; nslice++ ) {
    		int ns = -1;
    		for (int i = 0; i < labels.length; i++) {
    			if (slices[nslice].equals(labels[i])) {
    				ns = i;
    				break;
    			}
    		}
    		if (ns < 0) {
    			System.out.println("Slice "+slices[nslice]+" is not found in the image");
    			return -1;
    		} else {
    			corr_data[nslice] = (float[]) corr_stack.getPixels(ns + 1) ;
    		}
    	}
    	ArrayList<TfInTile> tf_tiles = new ArrayList<TfInTile>();
    	for (int tileY = 0; tileY < tilesY; tileY++) {
    		for (int tileX = 0; tileX < tilesX; tileX++) {
    			if ((tileY==dbgY) && (tileX==dbgX)) {
    				System.out.println("tileY = "+tileY+", tileX = "+tileX+" tile = "+(tileY*tilesX + tileX));
    			}
    			TfInTile tf_tile = new TfInTile();
    			tf_tile.setFromCorr2d(
    					corr_data, // float [][] data,
    					tileX,     // int  tileX,
    					tileY,     // int  tileY,
    					corr_side, // int  corr_side,
    					tilesX);   // int  tilesX
    			if (keep_empty || !Float.isNaN(tf_tile.target_disparity)) {
//    				if (Float.isNaN(tf_tile.target_disparity)) {
//    					tf_tile.target_disparity = 0.0f;
//    					tf_tile.gt_strength =      0.0f;
//    				}
    				tf_tiles.add(tf_tile);
    			}
    		}
    	}

    	// sets the limit to the capacity and the position to zero.
    	fb_corr2d.clear();
    	fb_target_disparity.clear();
    	fb_tiles_stage1.clear();
    	fb_tiles_stage2.clear();
    	for (int i = 0; i < tf_tiles.size(); i++) {
    		TfInTile tf_tile = tf_tiles.get(i);
    		fb_corr2d.put          (tf_tile.corr2d);
    		float td = tf_tile.target_disparity;
    		if (Float.isNaN(td)) td = 0.0f;
    		fb_target_disparity.put(i, td);
    		fb_tiles_stage1.put(tf_tile.tile);
    		fb_tiles_stage2.put(tf_tile.tile);
    	}
//    	fb_target_disparity.limit(tf_tiles.size()); // put float absolute does not movew the pointer
    	fb_target_disparity.position(tf_tiles.size()); // put float absolute does not movew the pointer
    	fb_tiles_stage1.    position(tf_tiles.size()); // put float absolute does not movew the pointer
    	fb_tiles_stage2.    position(tf_tiles.size()); // put float absolute does not movew the pointer
    	//sets the limit to the current position and then sets the position to zero.
    	fb_corr2d.          flip();
    	fb_target_disparity.flip();
    	fb_tiles_stage1.    flip();
    	fb_tiles_stage2.    flip();
    	String [] titles = {"predicted", "target", "gt_disparity", "gt_strength", "nn_out", "nn_error","abs_err","abs_heur","clean_nn"};
    	fb_predicted.rewind(); // not needed for absolute get();
    	float [][] result = new float [titles.length][tilesX*tilesY];
/*
    	for (int i = 0; i < tf_tiles.size(); i++) {
    		TfInTile tf_tile = tf_tiles.get(i);
//    		result[0][i] = tf_tile.target_disparity + fb_predicted.get(i);
    		result[1][i] = tf_tile.target_disparity;
    		result[2][i] = tf_tile.gt_disparity;
    		result[3][i] = tf_tile.gt_strength;
    	}

		(new showDoubleFloatArrays()).showArrays(
				result,
				tilesX,
				tilesY,
				true,
				"NN_disparity-pre",
				titles);
*/


//if (corr_side > 0) return 0; // always
    	run_stage1(
    			fb_corr2d,           // FloatBuffer fb_corr2d,
    			fb_target_disparity, // FloatBuffer fb_target_disparity,
    			fb_tiles_stage1);    // IntBuffer   fb_tiles_stage1);
    	run_stage2(
    			fb_tiles_stage2, // IntBuffer   fb_tiles_stage2,
    			fb_predicted); // FloatBuffer fb_disparity)


    	if (!keep_empty) {
    		for (int i =0; i < result[0].length; i++) {
    			result[0][i]=Float.NaN;
    			result[1][i]=Float.NaN;
    			result[2][i]=Float.NaN;
    		}
    	}

    	fb_predicted.rewind(); // not needed for absolute get();
    	for (int i = 0; i < tf_tiles.size(); i++) {
    		TfInTile tf_tile = tf_tiles.get(i);
    		result[0][i] = tf_tile.target_disparity + fb_predicted.get(i);
    		result[1][i] = tf_tile.target_disparity;
    		result[2][i] = tf_tile.gt_disparity;
    		result[3][i] = tf_tile.gt_strength;
    		result[4][i] = fb_predicted.get(i);
    		result[5][i] = (result[3][i] > 0)?(result[0][i]-result[2][i]):Float.NaN;
    		result[6][i] = (result[3][i] > 0)?Math.abs(result[5][i]):Float.NaN;
    		result[7][i] = (result[3][i] > 0)?Math.abs(result[1][i] - result[2][i]):Float.NaN;
    		result[8][i] = (result[3][i] > 0)?result[0][i]:Float.NaN;

    	}
		(new ShowDoubleFloatArrays()).showArrays(
				result,
				tilesX,
				tilesY,
				true,
				"NN_disparity",
				titles);

    	return 0;
	}


}