Commit f2ca38ae authored by Oleg Dzhimiev's avatar Oleg Dzhimiev

updated to 1.15.2

parent 963e5a90
...@@ -8,12 +8,12 @@ ...@@ -8,12 +8,12 @@
<dependency> <dependency>
<groupId>org.tensorflow</groupId> <groupId>org.tensorflow</groupId>
<artifactId>libtensorflow</artifactId> <artifactId>libtensorflow</artifactId>
<version>1.15.0</version> <version>1.15.2</version>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.tensorflow</groupId> <groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni_gpu</artifactId> <artifactId>libtensorflow_jni_gpu</artifactId>
<version>1.15.0</version> <version>1.15.2</version>
</dependency> </dependency>
<!-- https://mvnrepository.com/artifact/org.tensorflow/proto --> <!-- https://mvnrepository.com/artifact/org.tensorflow/proto -->
<dependency> <dependency>
......
...@@ -155,9 +155,9 @@ public class tfhello{ ...@@ -155,9 +155,9 @@ public class tfhello{
System.out.println(TensorFlow.version()); System.out.println(TensorFlow.version());
System.out.println("Test 3 end\n"); System.out.println("Test 3 end\n");
System.out.println("Test 4 start\n - Test simple custom JNI function added to TF"); //System.out.println("Test 4 start\n - Test simple custom JNI function added to TF");
System.out.println(TensorFlow.elphelVersion()); //System.out.println(TensorFlow.elphelVersion());
System.out.println("Test 4 end\n"); //System.out.println("Test 4 end\n");
//callableOpts.newBuilder().putFeedDevices(key, value); //callableOpts.newBuilder().putFeedDevices(key, value);
...@@ -220,14 +220,14 @@ public class tfhello{ ...@@ -220,14 +220,14 @@ public class tfhello{
// natively got GPU device name to insert into options // natively got GPU device name to insert into options
// it's the same all the time // it's the same all the time
String gpuDeviceName = s.elphelGPUDeviceName(); String gpuDeviceName = s.GPUDeviceName();
// GPU allocation: dims must be power of 2? // GPU allocation: dims must be power of 2?
Tensor t3 = Tensor.elphelCreateGPUTensor(new long[]{256},DataType.FLOAT); Tensor t3 = Tensor.createGPU(new long[]{256},DataType.FLOAT);
//System.out.println(t2.nativeRef); //System.out.println(t2.nativeRef);
// Let's check what happened // Let's check what happened
long t3_gpuptr = t3.elphel_GetGPUTensorPointer(); long t3_gpuptr = t3.GPUPointer();
// Print address // Print address
//System.out.println("Pointer address: "+String.format("0x%08x", t3_gpuptr)); //System.out.println("Pointer address: "+String.format("0x%08x", t3_gpuptr));
...@@ -276,8 +276,8 @@ public class tfhello{ ...@@ -276,8 +276,8 @@ public class tfhello{
System.out.println(callableOpts); System.out.println(callableOpts);
// callable handle // callable handle
long feed_gpu_fetch_cpu = s.MakeCallable(callableOpts.toByteArray()); long feed_gpu_fetch_cpu = s.makeCallable(callableOpts.toByteArray());
Tensor<?> t3out = s.runner().fetch("array_tensor_out").feed("array_tensor_in",t3).runElphelCallable(feed_gpu_fetch_cpu).get(0); Tensor<?> t3out = s.runner().fetch("array_tensor_out").feed("array_tensor_in",t3).runCallable(feed_gpu_fetch_cpu).get(0);
System.out.println(t3); System.out.println(t3);
System.out.println(t3out); System.out.println(t3out);
......
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow;
import java.util.ArrayList;
import java.util.List;
/**
* Driver for {@link Graph} execution.
*
* <p>A {@code Session} instance encapsulates the environment in which {@link Operation}s in a
* {@link Graph} are executed to compute {@link Tensor Tensors}. For example:
*
* <pre>{@code
* // Let's say graph is an instance of the Graph class
* // for the computation y = 3 * x
*
* try (Session s = new Session(graph)) {
* try (Tensor x = Tensor.create(2.0f);
* Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
* System.out.println(y.floatValue()); // Will print 6.0f
* }
* try (Tensor x = Tensor.create(1.1f);
* Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
* System.out.println(y.floatValue()); // Will print 3.3f
* }
* }
* }</pre>
*
* <p><b>WARNING:</b>A {@code Session} owns resources that <b>must</b> be explicitly freed by
* invoking {@link #close()}.
*
* <p>Instances of a Session are thread-safe.
*/
public final class Session implements AutoCloseable {

  /** Construct a new session with the associated {@link Graph}. */
  public Session(Graph g) {
    this(g, null);
  }

  /**
   * Construct a new session with the associated {@link Graph} and configuration options.
   *
   * @param g The {@link Graph} the created Session will operate on.
   * @param config Configuration parameters for the session specified as a serialized <a
   *     href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">ConfigProto</a>
   *     protocol buffer.
   * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto
   *     protocol buffer.
   */
  public Session(Graph g, byte[] config) {
    graph = g;
    Graph.Reference r = g.ref();
    try {
      nativeHandle =
          (config == null) ? allocate(r.nativeHandle()) : allocate2(r.nativeHandle(), null, config);
      // Hold a second Graph reference for the lifetime of the session itself; released in close().
      graphRef = g.ref();
    } finally {
      r.close();
    }
  }

  /** Wrap an existing session with the associated {@link Graph}. */
  Session(Graph g, long nativeHandle) {
    graph = g;
    this.nativeHandle = nativeHandle;
    graphRef = g.ref();
  }

  /**
   * Returns the GPU device name known to the native runtime (Elphel extension).
   *
   * <p>NOTE(review): semantics come from the patched native library; per the caller comments in
   * the tfhello test program the value appears constant for the life of the process — confirm.
   */
  public String elphelGPUDeviceName() {
    return elphelGetGPUDeviceName(this.nativeHandle);
  }

  /**
   * Java-naming-convention alias for {@link #elphelGPUDeviceName()}, matching callers updated to
   * the shorter name.
   */
  public String GPUDeviceName() {
    return elphelGPUDeviceName();
  }

  private native String elphelGetGPUDeviceName(long handle);

  /**
   * Creates a native callable for this session (Elphel extension).
   *
   * @param config serialized <a
   *     href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">CallableOptions</a>
   *     protocol buffer describing feeds/fetches and their devices
   * @return an opaque native callable handle, to be passed to {@link Runner#runElphelCallable(long)}
   */
  public long MakeCallable(byte[] config) {
    return elphelMakeCallable(this.nativeHandle, config);
  }

  /** Java-naming-convention alias for {@link #MakeCallable(byte[])}. */
  public long makeCallable(byte[] config) {
    return MakeCallable(config);
  }

  private native long elphelMakeCallable(long nativeHandle, byte[] config);

  /**
   * Release resources associated with the Session.
   *
   * <p>Blocks until there are no active executions ({@link Session.Runner#run()} calls). A Session
   * is not usable after close returns.
   */
  @Override
  public void close() {
    graphRef.close();
    synchronized (nativeHandleLock) {
      if (nativeHandle == 0) {
        return;
      }
      // Wait for all in-flight runs to drain; Runner.Reference.close() notifies this lock.
      while (numActiveRuns > 0) {
        try {
          nativeHandleLock.wait();
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          // Possible leak of the Session and Graph in this case?
          return;
        }
      }
      delete(nativeHandle);
      nativeHandle = 0;
    }
  }

  /**
   * Run {@link Operation}s and evaluate {@link Tensor Tensors}.
   *
   * <p>A Runner runs the necessary graph fragments to execute every {@link Operation} required to
   * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String,int,Tensor)} call allows
   * callers to override the value of {@link Tensor Tensors} in the graph by substituting the
   * provided {@link Tensor Tensors} for the outputs of the operations provided to {@link
   * #feed(String,int,Tensor)}.
   */
  public final class Runner {
    /**
     * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces.
     *
     * @param operation Is either the string name of the operation, in which case this method is a
     *     shorthand for {@code feed(operation, 0)}, or it is a string of the form
     *     <tt>operation_name:output_index</tt> , in which case this method acts like {@code
     *     feed(operation_name, output_index)}. These colon-separated names are commonly used in the
     *     {@code SignatureDef} protocol buffer messages that are included in {@link
     *     SavedModelBundle#metaGraphDef()}.
     */
    public Runner feed(String operation, Tensor<?> t) {
      return feed(parseOutput(operation), t);
    }

    /**
     * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t}
     * for the value it produces.
     *
     * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
     * one {@code t} is being provided for.
     */
    public Runner feed(String operation, int index, Tensor<?> t) {
      Operation op = operationByName(operation);
      if (op != null) {
        inputs.add(op.output(index));
        inputTensors.add(t);
      }
      return this;
    }

    /**
     * Use {@code t} instead of the Tensor referred to by executing the operation referred to by
     * {@code operand}.
     */
    public Runner feed(Operand<?> operand, Tensor<?> t) {
      inputs.add(operand.asOutput());
      inputTensors.add(t);
      return this;
    }

    /**
     * Feed for {@link #runElphelCallable(long)} - just a tensor. The corresponding graph input is
     * identified by the CallableOptions used to create the callable, not by this Runner.
     */
    public Runner feed(Tensor<?> t) {
      inputTensors.add(t);
      return this;
    }

    /**
     * Make {@link #run()} return the output of {@code operation}.
     *
     * @param operation Is either the string name of the operation, in which case this method is a
     *     shorthand for {@code fetch(operation, 0)}, or it is a string of the form
     *     <tt>operation_name:output_index</tt> , in which case this method acts like {@code
     *     fetch(operation_name, output_index)}. These colon-separated names are commonly used in
     *     the {@code SignatureDef} protocol buffer messages that are included in {@link
     *     SavedModelBundle#metaGraphDef()}.
     */
    public Runner fetch(String operation) {
      return fetch(parseOutput(operation));
    }

    /**
     * Make {@link #run()} return the {@code index}-th output of {@code operation}.
     *
     * <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
     * one to return.
     */
    public Runner fetch(String operation, int index) {
      Operation op = operationByName(operation);
      if (op != null) {
        outputs.add(op.output(index));
      }
      return this;
    }

    /** Makes {@link #run()} return the Tensor referred to by {@code output}. */
    public Runner fetch(Output<?> output) {
      outputs.add(output);
      return this;
    }

    /** Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. */
    public Runner fetch(Operand<?> operand) {
      return fetch(operand.asOutput());
    }

    /**
     * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor
     * Tensors}.
     */
    public Runner addTarget(String operation) {
      GraphOperation op = operationByName(operation);
      if (op != null) {
        targets.add(op);
      }
      return this;
    }

    /**
     * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor
     * Tensors}.
     *
     * @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
     */
    public Runner addTarget(Operation operation) {
      if (!(operation instanceof GraphOperation)) {
        throw new IllegalArgumentException(
            "Operation of type "
                + operation.getClass().getName()
                + " is not supported in graph sessions");
      }
      targets.add((GraphOperation) operation);
      return this;
    }

    /**
     * Make {@link #run} execute {@code operand}, but not return any evaluated {@link Tensor
     * Tensors}.
     */
    public Runner addTarget(Operand<?> operand) {
      return addTarget(operand.asOutput().op());
    }

    /**
     * (Experimental method): set options (typically for debugging) for this run.
     *
     * <p>The options are presented as a serialized <a
     * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions
     * protocol buffer</a>.
     *
     * <p>The org.tensorflow package is free of any protocol buffer dependencies in order to remain
     * friendly to resource constrained systems (where something like <a
     * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may
     * be more appropriate). A cost of that is this lack of type-safety in this API function. This
     * choice is under review and this function may be replaced by more type-safe equivalents at any
     * time.
     */
    public Runner setOptions(byte[] options) {
      this.runOptions = options;
      return this;
    }

    /**
     * Execute the graph fragments necessary to compute all requested fetches.
     *
     * <p><b>WARNING:</b> The caller assumes ownership of all returned {@link Tensor Tensors}, i.e.,
     * the caller must call {@link Tensor#close} on all elements of the returned list to free up
     * resources.
     *
     * <p>TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it
     * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in
     * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a
     * {@code Map<Output, Tensor>}?
     *
     * <p>TODO(andrewmyers): It would also be good if whatever is returned here made it easier to
     * extract output tensors in a type-safe way.
     */
    public List<Tensor<?>> run() {
      return runHelper(false).outputs;
    }

    /**
     * Execute graph fragments to compute requested fetches and return metadata about the run.
     *
     * <p>This is exactly like {@link #run()}, but in addition to the requested Tensors, also
     * returns metadata about the graph execution in the form of a serialized <a
     * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
     * protocol buffer</a>.
     */
    public Run runAndFetchMetadata() {
      return runHelper(true);
    }

    /**
     * Executes a native callable created by {@link Session#MakeCallable(byte[])} (Elphel
     * extension), feeding the tensors collected via {@link #feed(Tensor)} and returning the
     * fetched tensors. The caller owns the returned tensors and must close them.
     */
    public List<Tensor<?>> runElphelCallable(long handle) {
      return runElphelCallableHelper(handle).outputs;
    }

    /** Java-naming-convention alias for {@link #runElphelCallable(long)}. */
    public List<Tensor<?>> runCallable(long handle) {
      return runElphelCallable(handle);
    }

    /** Feeds the collected input tensors to the native callable and wraps its outputs. */
    private Run runElphelCallableHelper(long handle) {
      long[] inputTensorHandles = new long[inputTensors.size()];
      long[] outputTensorHandles = new long[outputs.size()];
      // It's okay to use Tensor.getNativeHandle() here since the safety depends on the
      // validity of the Graph and graphRef ensures that.
      int idx = 0;
      for (Tensor<?> t : inputTensors) {
        inputTensorHandles[idx++] = t.getNativeHandle();
      }
      Reference runRef = new Reference();
      byte[] metadata = null;
      try {
        metadata =
            Session.elphelRunCallable(
                nativeHandle, handle, inputTensorHandles, outputTensorHandles);
      } finally {
        runRef.close();
      }
      Run ret = new Run();
      ret.outputs = resolveOutputTensors(outputTensorHandles);
      ret.metadata = metadata;
      return ret;
    }

    private Run runHelper(boolean wantMetadata) {
      long[] inputTensorHandles = new long[inputTensors.size()];
      long[] inputOpHandles = new long[inputs.size()];
      int[] inputOpIndices = new int[inputs.size()];
      long[] outputOpHandles = new long[outputs.size()];
      int[] outputOpIndices = new int[outputs.size()];
      long[] targetOpHandles = new long[targets.size()];
      long[] outputTensorHandles = new long[outputs.size()];

      // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the
      // validity of the Graph and graphRef ensures that.
      int idx = 0;
      for (Tensor<?> t : inputTensors) {
        inputTensorHandles[idx++] = t.getNativeHandle();
      }
      idx = 0;
      for (Output<?> o : inputs) {
        inputOpHandles[idx] = o.getUnsafeNativeHandle();
        inputOpIndices[idx] = o.index();
        idx++;
      }
      idx = 0;
      for (Output<?> o : outputs) {
        outputOpHandles[idx] = o.getUnsafeNativeHandle();
        outputOpIndices[idx] = o.index();
        idx++;
      }
      idx = 0;
      for (GraphOperation op : targets) {
        targetOpHandles[idx++] = op.getUnsafeNativeHandle();
      }
      Reference runRef = new Reference();
      byte[] metadata = null;
      try {
        metadata =
            Session.run(
                nativeHandle,
                runOptions,
                inputTensorHandles,
                inputOpHandles,
                inputOpIndices,
                outputOpHandles,
                outputOpIndices,
                targetOpHandles,
                wantMetadata,
                outputTensorHandles);
      } finally {
        runRef.close();
      }
      Run ret = new Run();
      ret.outputs = resolveOutputTensors(outputTensorHandles);
      ret.metadata = metadata;
      return ret;
    }

    /**
     * Wraps freshly produced native tensor handles in {@link Tensor} objects. If any wrap fails,
     * every tensor created so far is closed before the exception propagates, so the native memory
     * behind the remaining handles is not leaked. Shared by {@link #runHelper(boolean)} and
     * {@link #runElphelCallableHelper(long)}.
     */
    private List<Tensor<?>> resolveOutputTensors(long[] outputTensorHandles) {
      List<Tensor<?>> result = new ArrayList<Tensor<?>>();
      for (long h : outputTensorHandles) {
        try {
          result.add(Tensor.fromHandle(h));
        } catch (Exception e) {
          for (Tensor<?> t : result) {
            t.close();
          }
          result.clear();
          throw e;
        }
      }
      return result;
    }

    /**
     * Tracks an in-flight run: increments numActiveRuns for its lifetime so {@link Session#close()}
     * can wait for all executions to finish before deleting the native session.
     */
    private class Reference implements AutoCloseable {
      public Reference() {
        synchronized (nativeHandleLock) {
          if (nativeHandle == 0) {
            throw new IllegalStateException("run() cannot be called on the Session after close()");
          }
          ++numActiveRuns;
        }
      }

      @Override
      public void close() {
        synchronized (nativeHandleLock) {
          if (nativeHandle == 0) {
            return;
          }
          if (--numActiveRuns == 0) {
            // Wake up a close() that may be waiting for active runs to drain.
            nativeHandleLock.notifyAll();
          }
        }
      }
    }

    /** Looks up an operation by name, failing loudly if the graph does not contain it. */
    private GraphOperation operationByName(String opName) {
      GraphOperation op = graph.operation(opName);
      if (op == null) {
        throw new IllegalArgumentException("No Operation named [" + opName + "] in the Graph");
      }
      return op;
    }

    /**
     * Parses "op_name" or "op_name:output_index" into an {@link Output}. A name whose suffix
     * after the last ':' is not a number is treated as a plain operation name (output index 0).
     */
    @SuppressWarnings("rawtypes")
    private Output<?> parseOutput(String opName) {
      int colon = opName.lastIndexOf(':');
      if (colon == -1 || colon == opName.length() - 1) {
        return new Output(operationByName(opName), 0);
      }
      try {
        String op = opName.substring(0, colon);
        int index = Integer.parseInt(opName.substring(colon + 1));
        return new Output(operationByName(op), index);
      } catch (NumberFormatException e) {
        return new Output(operationByName(opName), 0);
      }
    }

    private ArrayList<Output<?>> inputs = new ArrayList<Output<?>>();
    private ArrayList<Tensor<?>> inputTensors = new ArrayList<Tensor<?>>();
    private ArrayList<Output<?>> outputs = new ArrayList<Output<?>>();
    private ArrayList<GraphOperation> targets = new ArrayList<GraphOperation>();
    private byte[] runOptions = null;
  }

  /** Create a Runner to execute graph operations and evaluate Tensors. */
  public Runner runner() {
    return new Runner();
  }

  /**
   * Output tensors and metadata obtained when executing a session.
   *
   * <p>See {@link Runner#runAndFetchMetadata()}
   */
  public static final class Run {
    /** Tensors from requested fetches. */
    public List<Tensor<?>> outputs;

    /**
     * (Experimental): Metadata about the run.
     *
     * <p>A serialized <a
     * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata
     * protocol buffer</a>. The org.tensorflow package is free of any protocol buffer dependencies
     * in order to remain friendly to resource constrained systems (where something like <a
     * href="https://github.com/google/protobuf/tree/master/javanano#nano-version">nanoproto</a> may
     * be more appropriate). A cost of that is this opaque blob. This choice is under review and
     * this field may be replaced by more type-safe equivalents at any time.
     */
    public byte[] metadata;
  }

  private final Graph graph;
  private final Graph.Reference graphRef;
  private final Object nativeHandleLock = new Object();
  private long nativeHandle;
  private int numActiveRuns;

  // TODO(ashankar): Remove after TensorFlow 1.2 has been released with allocate2().
  private static native long allocate(long graphHandle);

  private static native long allocate2(long graphHandle, String target, byte[] config);

  private static native void delete(long handle);

  /**
   * Execute a session.
   *
   * <p>The author apologizes for the ugliness of the long argument list of this method. However,
   * take solace in the fact that this is a private method meant to cross the JNI boundary.
   *
   * @param handle to the C API TF_Session object (Session.nativeHandle)
   * @param runOptions serialized representation of a RunOptions protocol buffer, or null
   * @param inputOpHandles (see inputOpIndices)
   * @param inputOpIndices (see inputTensorHandles)
   * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values
   *     that are being "fed" (do not need to be computed) during graph execution.
   *     inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the
   *     inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that
   *     inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
   * @param outputOpHandles (see outputOpIndices)
   * @param outputOpIndices together with outputOpHandles identifies the set of values that should
   *     be computed. The outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is
   *     required that outputOpHandles.length == outputOpIndices.length.
   * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose
   *     output will not be returned
   * @param wantRunMetadata indicates whether metadata about this execution should be returned.
   * @param outputTensorHandles will be filled in with handles to the outputs requested. It is
   *     required that outputTensorHandles.length == outputOpHandles.length.
   * @return if wantRunMetadata is true, serialized representation of the RunMetadata protocol
   *     buffer, {@code null} otherwise.
   */
  private static native byte[] run(
      long handle,
      byte[] runOptions,
      long[] inputTensorHandles,
      long[] inputOpHandles,
      int[] inputOpIndices,
      long[] outputOpHandles,
      int[] outputOpIndices,
      long[] targetOpHandles,
      boolean wantRunMetadata,
      long[] outputTensorHandles);

  /**
   * Executes a native callable created by {@link #MakeCallable(byte[])} (Elphel extension).
   *
   * @param sessionHandle handle to the C API TF_Session object
   * @param callableHandle handle returned by the native MakeCallable
   * @param inputTensorHandles native handles of the tensors fed to the callable
   * @param outputTensorHandles filled in with native handles of the fetched tensors
   * @return serialized run metadata — NOTE(review): the exact contract is defined by the patched
   *     native library; confirm against its JNI implementation
   */
  private static native byte[] elphelRunCallable(
      long sessionHandle,
      long callableHandle,
      long[] inputTensorHandles,
      long[] outputTensorHandles);
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow;
import java.lang.reflect.Array;
import java.nio.Buffer;
import java.nio.BufferOverflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.Arrays;
import java.util.HashMap;
/**
* A statically typed multi-dimensional array whose elements are of a type described by T.
*
* <p>Instances of a Tensor are <b>not</b> thread-safe.
*
* <p><b>WARNING:</b> Resources consumed by the Tensor object <b>must</b> be explicitly freed by
* invoking the {@link #close()} method when the object is no longer needed. For example, using a
* try-with-resources block:
*
* <pre>{@code
* try (Tensor t = Tensor.create(...)) {
* doSomethingWith(t);
* }
* }</pre>
*/
public final class Tensor<T> implements AutoCloseable {
/**
* Creates a Tensor from a Java object.
*
* <p>A {@code Tensor} is a multi-dimensional array of elements of a limited set of types. Not all
* Java objects can be converted to a {@code Tensor}. In particular, the argument {@code obj} must
* be either a primitive (float, double, int, long, boolean, byte) or a multi-dimensional array of
* one of those primitives. The argument {@code type} specifies how to interpret the first
* argument as a TensorFlow type. For example:
*
* <pre>{@code
* // Valid: A 64-bit integer scalar.
* Tensor<Long> s = Tensor.create(42L, Long.class);
*
* // Valid: A 3x2 matrix of floats.
* float[][] matrix = new float[3][2];
* Tensor<Float> m = Tensor.create(matrix, Float.class);
*
* // Invalid: Will throw an IllegalArgumentException as an arbitrary Object
* // does not fit into the TensorFlow type system.
* Tensor<?> o = Tensor.create(new Object())
*
* // Invalid: Will throw an IllegalArgumentException since there are
* // a differing number of elements in each row of this 2-D array.
* int[][] twoD = new int[2][];
* twoD[0] = new int[1];
* twoD[1] = new int[2];
* Tensor<Integer> x = Tensor.create(twoD, Integer.class);
* }</pre>
*
* {@link String}-typed Tensors are multi-dimensional arrays of arbitrary byte sequences, so can
* be initialized from arrays of {@code byte[]} elements. For example:
*
* <pre>{@code
* // Valid: A String tensor.
* Tensor<String> s = Tensor.create(new byte[]{1, 2, 3}, String.class);
*
* // Java Strings will need to be encoded into a byte-sequence.
* String mystring = "foo";
* Tensor<String> s = Tensor.create(mystring.getBytes("UTF-8"), String.class);
*
* // Valid: Matrix of String tensors.
* // Each element might have a different length.
* byte[][][] matrix = new byte[2][2][];
* matrix[0][0] = "this".getBytes("UTF-8");
* matrix[0][1] = "is".getBytes("UTF-8");
* matrix[1][0] = "a".getBytes("UTF-8");
* matrix[1][1] = "matrix".getBytes("UTF-8");
* Tensor<String> m = Tensor.create(matrix, String.class);
* }</pre>
*
* @param obj The object to convert to a {@code Tensor<T>}. Note that whether it is compatible
* with the type T is not checked by the type system. For type-safe creation of tensors, use
* {@link Tensors}.
* @param type The class object representing the type T.
* @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type
* system.
*/
@SuppressWarnings("unchecked")
public static <T> Tensor<T> create(Object obj, Class<T> type) {
  // Resolve the TensorFlow data type requested via the class token, then verify the
  // runtime shape/type of obj actually matches it before delegating to the untyped factory.
  final DataType requested = DataType.fromClass(type);
  if (objectCompatWithType(obj, requested)) {
    return (Tensor<T>) create(obj, requested);
  }
  throw new IllegalArgumentException(
      "DataType of object does not match T (expected "
          + requested
          + ", got "
          + dataTypeOf(obj)
          + ")");
}
/**
* Creates a tensor from an object whose class is inspected to figure out what the underlying data
* type should be.
*
* @throws IllegalArgumentException if {@code obj} is not compatible with the TensorFlow type
* system.
*/
public static Tensor<?> create(Object obj) {
  // Infer the element type from the object's runtime class, then build the tensor.
  final DataType inferred = dataTypeOf(obj);
  return create(obj, inferred);
}
/**
* Create a Tensor of data type {@code dtype} from a Java object. Requires the parameter {@code T}
* to match {@code type}, but this condition is not checked.
*
* @param obj the object supplying the tensor data.
* @param dtype the data type of the tensor to create. It must be compatible with the run-time
* type of the object.
* @return the new tensor
*/
private static Tensor<?> create(Object obj, DataType dtype) {
@SuppressWarnings("rawtypes")
Tensor<?> t = new Tensor(dtype);
// Derive the tensor shape by walking the (possibly nested) Java array structure of obj.
t.shapeCopy = new long[numDimensions(obj, dtype)];
fillShape(obj, 0, t.shapeCopy);
long nativeHandle;
if (t.dtype != DataType.STRING) {
// Fixed-width element types: allocate one flat native buffer and copy the values in.
int byteSize = elemByteSize(t.dtype) * numElements(t.shapeCopy);
nativeHandle = allocate(t.dtype.c(), t.shapeCopy, byteSize);
setValue(nativeHandle, obj);
} else if (t.shapeCopy.length != 0) {
// Non-scalar DT_STRING: obj is a (possibly nested) array of byte[] sequences.
nativeHandle = allocateNonScalarBytes(t.shapeCopy, (Object[]) obj);
} else {
// Scalar DT_STRING: obj is a single byte[] sequence.
nativeHandle = allocateScalarBytes((byte[]) obj);
}
// NativeReference ties the native tensor's lifetime to this object (freed via close()).
t.nativeRef = new NativeReference(nativeHandle);
return t;
}
/**
 * Creates an uninitialized Tensor whose backing buffer is allocated in GPU memory
 * (Elphel extension; requires the patched native TensorFlow library).
 *
 * @param shape the tensor shape; defensively copied so later mutation of the caller's array
 *     cannot corrupt this tensor's bookkeeping (consistent with allocateForBuffer)
 * @param dtype the element type of the tensor
 * @return the new GPU-resident tensor; the caller owns it and must call {@link #close()}
 */
public static Tensor<?> elphelCreateGPUTensor(long[] shape, DataType dtype) {
  @SuppressWarnings("rawtypes")
  Tensor<?> t = new Tensor(dtype);
  // Fix: snapshot the shape instead of aliasing the caller's array, matching the other factories.
  t.shapeCopy = Arrays.copyOf(shape, shape.length);
  long nativeHandle = elphelAllocateGPUTensor(t.shapeCopy, t.dtype.c());
  t.nativeRef = new NativeReference(nativeHandle);
  return t;
}

/** Short alias for {@link #elphelCreateGPUTensor(long[], DataType)} matching updated callers. */
public static Tensor<?> createGPU(long[] shape, DataType dtype) {
  return elphelCreateGPUTensor(shape, dtype);
}
/**
* Create a {@link Integer} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
* 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
* method.
*
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Tensor<Integer> create(long[] shape, IntBuffer data) {
  // Allocate a DT_INT32 tensor sized by the buffer's remaining elements, then copy them in
  // (consuming the buffer from its current position).
  final Tensor<Integer> tensor = allocateForBuffer(DataType.INT32, shape, data.remaining());
  tensor.buffer().asIntBuffer().put(data);
  return tensor;
}
/**
* Create a {@link Float} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
* 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
* method.
*
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Tensor<Float> create(long[] shape, FloatBuffer data) {
  // Allocate a DT_FLOAT tensor sized by the buffer's remaining elements, then copy them in
  // (consuming the buffer from its current position).
  final Tensor<Float> tensor = allocateForBuffer(DataType.FLOAT, shape, data.remaining());
  tensor.buffer().asFloatBuffer().put(data);
  return tensor;
}
/**
* Create a {@link Double} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
* 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
* method.
*
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Tensor<Double> create(long[] shape, DoubleBuffer data) {
  // Allocate a DT_DOUBLE tensor sized by the buffer's remaining elements, then copy them in
  // (consuming the buffer from its current position).
  final Tensor<Double> tensor = allocateForBuffer(DataType.DOUBLE, shape, data.remaining());
  tensor.buffer().asDoubleBuffer().put(data);
  return tensor;
}
/**
* Create an {@link Long} Tensor with data from the given buffer.
*
* <p>Creates a Tensor with the given shape by copying elements from the buffer (starting from its
* current position) into the tensor. For example, if {@code shape = {2,3} } (which represents a
* 2x3 matrix) then the buffer must have 6 elements remaining, which will be consumed by this
* method.
*
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor shape is not compatible with the buffer
*/
public static Tensor<Long> create(long[] shape, LongBuffer data) {
  // Allocate a DT_INT64 tensor sized by the buffer's remaining elements, then copy them in
  // (consuming the buffer from its current position).
  final Tensor<Long> tensor = allocateForBuffer(DataType.INT64, shape, data.remaining());
  tensor.buffer().asLongBuffer().put(data);
  return tensor;
}
/**
* Create a Tensor of any type with data from the given buffer.
*
* <p>Creates a Tensor with the provided shape of any type where the tensor's data has been
* encoded into {@code data} as per the specification of the TensorFlow <a
* href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C
* API</a>.
*
* @param <T> the tensor element type
* @param type the tensor element type, represented as a class object.
* @param shape the tensor shape.
* @param data a buffer containing the tensor data.
* @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the
* buffer
*/
public static <T> Tensor<T> create(Class<T> type, long[] shape, ByteBuffer data) {
  // Map the class token to a DataType and delegate to the untyped raw-buffer factory;
  // the cast is safe because the factory's dtype is derived from T's class.
  @SuppressWarnings("unchecked")
  Tensor<T> ret = (Tensor<T>) create(DataType.fromClass(type), shape, data);
  return ret;
}
private static Tensor<?> create(DataType dtype, long[] shape, ByteBuffer data) {
  // For fixed-width types the element count is bytes / element-size and the byte count must
  // divide evenly; DT_STRING tensors arrive pre-encoded, so the raw byte count is used as-is.
  final int bytesAvailable = data.remaining();
  int elementCount = bytesAvailable;
  if (dtype != DataType.STRING) {
    final int bytesPerElement = elemByteSize(dtype);
    if (bytesAvailable % bytesPerElement != 0) {
      throw new IllegalArgumentException(
          String.format(
              "ByteBuffer with %d bytes is not compatible with a %s Tensor (%d bytes/element)",
              bytesAvailable, dtype.toString(), bytesPerElement));
    }
    elementCount = bytesAvailable / bytesPerElement;
  }
  final Tensor<?> result = allocateForBuffer(dtype, shape, elementCount);
  result.buffer().put(data);
  return result;
}
/**
* Returns this Tensor object with the type {@code Tensor<U>}. This method is useful when given a
* value of type {@code Tensor<?>}.
*
* @param type any (non-null) array of the correct type.
* @throws IllegalArgumentException if the actual data type of this object does not match the type
* {@code U}.
*/
@SuppressWarnings("unchecked")
public <U> Tensor<U> expect(Class<U> type) {
  // Narrow Tensor<?> to Tensor<U> only after confirming the runtime element type matches;
  // otherwise fail with the same diagnostic as before.
  final DataType requested = DataType.fromClass(type);
  if (requested.equals(dtype)) {
    return (Tensor<U>) this;
  }
  throw new IllegalArgumentException(
      "Cannot cast from tensor of " + dtype + " to tensor of " + requested);
}
// Helper function to allocate a Tensor for the create() methods that create a Tensor from
// a java.nio.Buffer.
// Requires: dataType matches T
private static <T> Tensor<T> allocateForBuffer(DataType dataType, long[] shape, int nBuffered) {
  final int nflattened = numElements(shape);
  int nbytes = 0;
  if (dataType != DataType.STRING) {
    // For fixed-width dtypes the buffer must hold exactly one value per tensor element.
    if (nBuffered != nflattened) {
      throw incompatibleBuffer(nBuffered, shape);
    }
    // NOTE(review): int multiply — can overflow for tensors approaching 2^31 bytes.
    nbytes = nflattened * elemByteSize(dataType);
  } else {
    // DT_STRING tensor encoded in a ByteBuffer.
    nbytes = nBuffered;
  }
  Tensor<T> t = new Tensor<T>(dataType);
  // Defensive copy: the caller keeps ownership of its shape array.
  t.shapeCopy = Arrays.copyOf(shape, shape.length);
  long nativeHandle = allocate(t.dtype.c(), t.shapeCopy, nbytes);
  t.nativeRef = new NativeReference(nativeHandle);
  return t;
}
/**
 * Release resources associated with the Tensor.
 *
 * <p><b>WARNING:</b>This must be invoked for all tensors that were not produced by an eager
 * operation, or memory will be leaked.
 *
 * <p>The Tensor object is no longer usable after {@code close} returns.
 */
@Override
public void close() {
  // Idempotent: NativeReference.release() zeroes the handle on its first call.
  nativeRef.release();
}
/** Returns the {@link DataType} of elements stored in the Tensor. */
public DataType dataType() {
  return dtype;
}

/**
 * Returns the number of dimensions (sometimes referred to as <a
 * href="https://www.tensorflow.org/resources/dims_types.html#rank">rank</a>) of the Tensor.
 *
 * <p>Will be 0 for a scalar, 1 for a vector, 2 for a matrix, 3 for a 3-dimensional tensor etc.
 */
public int numDimensions() {
  return shapeCopy.length;
}

/** Returns the size, in bytes, of the tensor data. */
public int numBytes() {
  return buffer().remaining();
}

/** Returns the number of elements in a flattened (1-D) view of the tensor. */
public int numElements() {
  return numElements(shapeCopy);
}

/**
 * Returns the <a href="https://www.tensorflow.org/resources/dims_types.html#shape">shape</a> of
 * the Tensor, i.e., the sizes of each dimension.
 *
 * <p>NOTE(review): the internal array is returned without a defensive copy; a caller mutating
 * the result would corrupt this Tensor's bookkeeping.
 *
 * @return an array where the i-th element is the size of the i-th dimension of the tensor.
 */
public long[] shape() {
  return shapeCopy;
}

/**
 * Returns the value in a scalar {@link Float} tensor.
 *
 * @throws IllegalArgumentException if the Tensor does not represent a float scalar.
 */
public float floatValue() {
  return scalarFloat(getNativeHandle());
}

/**
 * Returns the value in a scalar {@link Double} tensor.
 *
 * @throws IllegalArgumentException if the Tensor does not represent a double scalar.
 */
public double doubleValue() {
  return scalarDouble(getNativeHandle());
}

/**
 * Returns the value in a scalar {@link Integer} tensor.
 *
 * @throws IllegalArgumentException if the Tensor does not represent a int scalar.
 */
public int intValue() {
  return scalarInt(getNativeHandle());
}

/**
 * Returns the value in a scalar {@link Long} tensor.
 *
 * @throws IllegalArgumentException if the Tensor does not represent a long scalar.
 */
public long longValue() {
  return scalarLong(getNativeHandle());
}

/**
 * Returns the value in a scalar {@link Boolean} tensor.
 *
 * @throws IllegalArgumentException if the Tensor does not represent a boolean scalar.
 */
public boolean booleanValue() {
  return scalarBoolean(getNativeHandle());
}

/**
 * Returns the value in a scalar {@link String} tensor.
 *
 * @throws IllegalArgumentException if the Tensor does not represent a String scalar.
 */
public byte[] bytesValue() {
  return scalarBytes(getNativeHandle());
}
/**
 * Copies the contents of the tensor to {@code dst} and returns {@code dst}.
 *
 * <p>For non-scalar tensors, this method copies the contents of the underlying tensor to a Java
 * array. For scalar tensors, use one of {@link #bytesValue()}, {@link #floatValue()}, {@link
 * #doubleValue()}, {@link #intValue()}, {@link #longValue()} or {@link #booleanValue()} instead.
 * The type and shape of {@code dst} must be compatible with the tensor. For example:
 *
 * <pre>{@code
 * int matrix[2][2] = {{1,2},{3,4}};
 * try(Tensor t = Tensor.create(matrix)) {
 *   // Succeeds and prints "3"
 *   int[][] copy = new int[2][2];
 *   System.out.println(t.copyTo(copy)[1][0]);
 *
 *   // Throws IllegalArgumentException since the shape of dst does not match the shape of t.
 *   int[][] dst = new int[4][1];
 *   t.copyTo(dst);
 * }
 * }</pre>
 *
 * @throws IllegalArgumentException if the tensor is a scalar or if {@code dst} is not compatible
 *     with the tensor (for example, mismatched data types or shapes).
 */
public <U> U copyTo(U dst) {
  // Validate rank, dtype, and per-dimension sizes before handing off to native code.
  throwExceptionIfTypeIsIncompatible(dst);
  readNDArray(getNativeHandle(), dst);
  return dst;
}

/**
 * Write the data of a {@link Integer} tensor into the given buffer.
 *
 * <p>Copies {@code numElements()} elements to the buffer.
 *
 * @param dst the destination buffer
 * @throws BufferOverflowException If there is insufficient space in the given buffer for the data
 *     in this tensor
 * @throws IllegalArgumentException If the tensor data type is not {@link Integer}
 */
public void writeTo(IntBuffer dst) {
  if (dtype != DataType.INT32) {
    throw incompatibleBuffer(dst, dtype);
  }
  // View the native memory as ints; the bulk put advances only dst's position.
  ByteBuffer src = buffer();
  dst.put(src.asIntBuffer());
}

/**
 * Write the data of a {@link Float} tensor into the given buffer.
 *
 * <p>Copies {@code numElements()} elements to the buffer.
 *
 * @param dst the destination buffer
 * @throws BufferOverflowException If there is insufficient space in the given buffer for the data
 *     in this tensor
 * @throws IllegalArgumentException If the tensor datatype is not {@link Float}
 */
public void writeTo(FloatBuffer dst) {
  if (dtype != DataType.FLOAT) {
    throw incompatibleBuffer(dst, dtype);
  }
  ByteBuffer src = buffer();
  dst.put(src.asFloatBuffer());
}

/**
 * Write the data of a {@link Double} tensor into the given buffer.
 *
 * <p>Copies {@code numElements()} elements to the buffer.
 *
 * @param dst the destination buffer
 * @throws BufferOverflowException If there is insufficient space in the given buffer for the data
 *     in this tensor
 * @throws IllegalArgumentException If the tensor datatype is not {@link Double}
 */
public void writeTo(DoubleBuffer dst) {
  if (dtype != DataType.DOUBLE) {
    throw incompatibleBuffer(dst, dtype);
  }
  ByteBuffer src = buffer();
  dst.put(src.asDoubleBuffer());
}

/**
 * Write the data of a {@link Long} tensor into the given buffer.
 *
 * <p>Copies {@code numElements()} elements to the buffer.
 *
 * @param dst the destination buffer
 * @throws BufferOverflowException If there is insufficient space in the given buffer for the data
 *     in this tensor
 * @throws IllegalArgumentException If the tensor datatype is not {@link Long}
 */
public void writeTo(LongBuffer dst) {
  if (dtype != DataType.INT64) {
    throw incompatibleBuffer(dst, dtype);
  }
  ByteBuffer src = buffer();
  dst.put(src.asLongBuffer());
}

/**
 * Write the tensor data into the given buffer.
 *
 * <p>Copies {@code numBytes()} bytes to the buffer in native byte order for primitive types.
 *
 * @param dst the destination buffer
 * @throws BufferOverflowException If there is insufficient space in the given buffer for the data
 *     in this tensor
 */
public void writeTo(ByteBuffer dst) {
  // Untyped variant: raw byte copy, valid for any dtype including STRING.
  ByteBuffer src = buffer();
  dst.put(src);
}
/** Describes the tensor as "&lt;DTYPE&gt; tensor with shape [d0, d1, ...]" (no element data). */
@Override
public String toString() {
  return dtype.toString() + " tensor with shape " + Arrays.toString(shape());
}
/*
public int elphel_isCUDATensor() {
  int result = elphelIsCUDATensor(getNativeHandle());
  return result;
}
*/
/**
 * Elphel extension: returns the raw data pointer of the underlying native tensor.
 *
 * <p>NOTE(review): presumably only meaningful for tensors whose buffer lives on the GPU —
 * confirm against the native elphelGetGPUTensorPointer implementation.
 */
public long elphel_GetGPUTensorPointer(){
  return elphelGetGPUTensorPointer(getNativeHandle());
}
/**
 * Create a Tensor object from a handle to the C TF_Tensor object.
 *
 * <p>Takes ownership of the handle.
 */
static Tensor<?> fromHandle(long handle) {
  @SuppressWarnings("rawtypes")
  Tensor<?> t = new Tensor(DataType.fromC(dtype(handle)));
  // Shape is read back from the native tensor, not supplied by the caller.
  t.shapeCopy = shape(handle);
  t.nativeRef = new NativeReference(handle);
  return t;
}

/**
 * Create an eager Tensor object from a handle to the C TF_Tensor object.
 *
 * <p>Takes ownership of the handle.
 */
static Tensor<?> fromHandle(long handle, EagerSession session) {
  Tensor<?> t = fromHandle(handle);
  // Attach to the eager session so the session's lifecycle can release the tensor.
  t.nativeRef.eager(session, t);
  return t;
}

// Raw TF_Tensor* as a long; 0 after release().
long getNativeHandle() {
  return nativeRef.tensorHandle;
}

// Wraps the native handle; assigned by every factory path, null only mid-construction.
private NativeReference nativeRef = null;
private final DataType dtype;
// Java-side copy of the tensor's dimensions.
private long[] shapeCopy = null;

// Instances are only created through static factories; dtype is fixed for life.
private Tensor(DataType t) {
  dtype = t;
}

// Direct view of the native tensor's backing memory, rebased to native byte order.
private ByteBuffer buffer() {
  return buffer(getNativeHandle()).order(ByteOrder.nativeOrder());
}
// Builds (does not throw) the standard error for a dtype/buffer-class mismatch.
private static IllegalArgumentException incompatibleBuffer(Buffer buf, DataType dataType) {
  return new IllegalArgumentException(
      String.format("cannot use %s with Tensor of type %s", buf.getClass().getName(), dataType));
}

// Builds (does not throw) the standard error for an element-count/shape mismatch.
private static IllegalArgumentException incompatibleBuffer(int numElements, long[] shape) {
  return new IllegalArgumentException(
      String.format(
          "buffer with %d elements is not compatible with a Tensor with shape %s",
          numElements, Arrays.toString(shape)));
}
/**
 * Number of elements in a tensor of the given fully-known shape.
 *
 * <p>Fix: the previous version truncated each dimension with a bare {@code (int)} cast and let
 * the product wrap, so oversized shapes could silently yield a wrong (even negative) count and a
 * mis-sized native allocation. Overflow now fails fast.
 *
 * @param shape per-dimension sizes; every entry must fit in an int.
 * @return the product of all dimensions (1 for a scalar / empty shape).
 * @throws ArithmeticException if a dimension or the running product exceeds int range.
 */
private static int numElements(long[] shape) {
  // assumes a fully-known shape
  int n = 1;
  for (int i = 0; i < shape.length; i++) {
    n = Math.multiplyExact(n, Math.toIntExact(shape[i]));
  }
  return n;
}
// Fixed per-element width in bytes for the given dtype; STRING reports a negative
// byteSize() because its elements are variable-length.
private static int elemByteSize(DataType dataType) {
  final int width = dataType.byteSize();
  if (width >= 0) {
    return width;
  }
  throw new IllegalArgumentException("STRING tensors do not have a fixed element size");
}
// Guards the scalar-bytes path: the argument's runtime class must be exactly byte[][]
// ("[[B" in JVM array notation).
// NOTE(review): the message mentions null elements while the check is about the runtime array
// type — presumably the callers reach this after a null element degraded the inferred type;
// confirm against the call sites before rewording.
private static void throwExceptionIfNotByteOfByteArrays(Object array) {
  if (!array.getClass().getName().equals("[[B")) {
    throw new IllegalArgumentException(
        "object cannot be converted to a Tensor as it includes an array with null elements");
  }
}
/**
 * Reference to the underlying native tensor
 *
 * <p>Tensors are commonly allocated in a `try-with-resources` statement, where they get
 * automatically released after executing the last line of the `try` block they were declared in.
 *
 * <p>They can also be attached to an eager session, where in this case their lifetime ends either
 * when this session is closed or when the Tensor instance is no longer referenced and has been
 * garbage-collected.
 *
 * <p>This helper class wraps the tensor native handle and supports both situations; If an eager
 * reference to the tensor exists, it will take care of releasing the tensor at the end of its
 * life. If the tensor is being explicitly closed before this happens, it will take care of
 * clearing its association with any eager session before cleaning up the resources.
 */
private static class NativeReference {
  /** Attaches this reference to an eager session */
  private class EagerReference extends EagerSession.NativeReference {
    EagerReference(EagerSession session, Tensor<?> tensor) {
      super(session, tensor);
    }

    @Override
    void delete() {
      // Mark this eager reference as cleared since it has been deleted by the session
      NativeReference.this.eagerRef = null;
      NativeReference.this.release();
    }
  }

  NativeReference(long tensorHandle) {
    this.tensorHandle = tensorHandle;
  }

  // At most one eager attachment per reference; a second attach is a programming error.
  void eager(EagerSession session, Tensor<?> tensor) {
    if (eagerRef != null) {
      throw new IllegalStateException("The tensor is already attached to an eager session");
    }
    eagerRef = new EagerReference(session, tensor);
  }

  // Synchronized so that an explicit close() racing the eager session's delete() frees the
  // native tensor exactly once; tensorHandle == 0 marks "already released".
  synchronized void release() {
    if (tensorHandle != 0L) {
      // Clear any remaining eager reference to this tensor
      if (eagerRef != null) {
        eagerRef.clear();
        eagerRef = null;
      }
      Tensor.delete(tensorHandle);
      tensorHandle = 0L;
    }
  }

  private long tensorHandle;
  private EagerReference eagerRef;
}
// Maps Java element classes (primitive and boxed) to their default TensorFlow dtype.
// Note byte/Byte default to STRING, not UINT8 (see dataTypeOf below).
private static HashMap<Class<?>, DataType> classDataTypes = new HashMap<>();

static {
  classDataTypes.put(int.class, DataType.INT32);
  classDataTypes.put(Integer.class, DataType.INT32);
  classDataTypes.put(long.class, DataType.INT64);
  classDataTypes.put(Long.class, DataType.INT64);
  classDataTypes.put(float.class, DataType.FLOAT);
  classDataTypes.put(Float.class, DataType.FLOAT);
  classDataTypes.put(double.class, DataType.DOUBLE);
  classDataTypes.put(Double.class, DataType.DOUBLE);
  classDataTypes.put(byte.class, DataType.STRING);
  classDataTypes.put(Byte.class, DataType.STRING);
  classDataTypes.put(boolean.class, DataType.BOOL);
  classDataTypes.put(Boolean.class, DataType.BOOL);
}

/** The class for the data type to which Java object o corresponds. */
private static Class<?> baseObjType(Object o) {
  // Strip array dimensions: int[][][] -> int.
  Class<?> c = o.getClass();
  while (c.isArray()) {
    c = c.getComponentType();
  }
  return c;
}

/**
 * The default TensorFlow data type to which Java object o corresponds. Some Java objects
 * represent more than one TensorFlow data type; for example, 'byte' can represent both {@code
 * uint8} and {@code string}, with the latter being the default interpretation.
 */
private static DataType dataTypeOf(Object o) {
  Class<?> c = baseObjType(o);
  return dataTypeFromClass(c);
}

// Looks up the default dtype for an element class; rejects unsupported classes.
private static DataType dataTypeFromClass(Class<?> c) {
  DataType ret = classDataTypes.get(c);
  if (ret != null) {
    return ret;
  }
  throw new IllegalArgumentException("cannot create Tensors of type " + c.getName());
}

/**
 * Return the number of dimensions of the tensor that object {@code o} represents as a tensor
 * whose datatype is {@code dtype}. Normally this is the same as the number of dimensions of o
 * itself, but is one smaller for tensors of strings.
 *
 * @param o The object to inspect. It must be a valid representation of the given data type.
 * @param dtype The expected data type of the tensor.
 */
private static int numDimensions(Object o, DataType dtype) {
  int ret = numArrayDimensions(o);
  if (dtype == DataType.STRING && ret > 0) {
    // The innermost byte[] is one string element, not a dimension.
    return ret - 1;
  }
  return ret;
}

/** Returns the number of dimensions of the array object o. Returns 0 if o is not an array. */
private static int numArrayDimensions(Object o) {
  Class<?> c = o.getClass();
  int i = 0;
  while (c.isArray()) {
    c = c.getComponentType();
    i++;
  }
  return i;
}

/**
 * Fills in the remaining entries in the shape array starting from position {@code dim} with the
 * dimension sizes of the multidimensional array o. Checks that all arrays reachable from o have
 * sizes consistent with the filled-in shape, throwing IllegalArgumentException otherwise.
 */
private static void fillShape(Object o, int dim, long[] shape) {
  if (shape == null || dim == shape.length) {
    return;
  }
  final int len = Array.getLength(o);
  if (len == 0) {
    throw new IllegalArgumentException("cannot create Tensors with a 0 dimension");
  }
  if (shape[dim] == 0) {
    // First sighting of this dimension: record its size.
    shape[dim] = len;
  } else if (shape[dim] != len) {
    // Ragged array: a sibling at this depth had a different length.
    throw new IllegalArgumentException(
        String.format("mismatched lengths (%d and %d) in dimension %d", shape[dim], len, dim));
  }
  for (int i = 0; i < len; ++i) {
    fillShape(Array.get(o, i), dim + 1, shape);
  }
}

/** Returns whether the object {@code obj} can represent a tensor with data type {@code dtype}. */
private static boolean objectCompatWithType(Object obj, DataType dtype) {
  Class<?> c = baseObjType(obj);
  DataType dto = dataTypeFromClass(c);
  int nd = numDimensions(obj, dto);
  // Arrays of boxed values (Integer[], ...) are rejected; only primitives/String nest.
  if (!c.isPrimitive() && c != String.class && nd != 0) {
    throw new IllegalArgumentException(
        "cannot create non-scalar Tensors from arrays of boxed values");
  }
  if (dto.equals(dtype)) {
    return true;
  }
  // byte data defaults to STRING but is also a valid representation of UINT8.
  if (dto == DataType.STRING && dtype == DataType.UINT8) {
    return true;
  }
  return false;
}
/**
 * Validates that the Java object {@code o} matches this tensor's rank, data type, and
 * per-dimension sizes, throwing {@link IllegalArgumentException} with a descriptive message on
 * the first mismatch. Used as the precondition check for {@link #copyTo(Object)}.
 */
private void throwExceptionIfTypeIsIncompatible(Object o) {
  final int rank = numDimensions();
  final int oRank = numDimensions(o, dtype);
  if (oRank != rank) {
    throw new IllegalArgumentException(
        String.format(
            "cannot copy Tensor with %d dimensions into an object with %d", rank, oRank));
  }
  if (!objectCompatWithType(o, dtype)) {
    throw new IllegalArgumentException(
        String.format(
            "cannot copy Tensor with DataType %s into an object of type %s",
            dtype.toString(), o.getClass().getName()));
  }
  long[] oShape = new long[rank];
  fillShape(o, 0, oShape);
  // Hoisted: shape() was previously re-invoked on every loop iteration.
  final long[] tShape = shape();
  for (int i = 0; i < oShape.length; ++i) {
    if (oShape[i] != tShape[i]) {
      throw new IllegalArgumentException(
          String.format(
              "cannot copy Tensor with shape %s into object with shape %s",
              Arrays.toString(tShape), Arrays.toString(oShape)));
    }
  }
}
// --- Native entry points (implemented in tensor_jni.cc) ---

// Allocation: returns a TF_Tensor* packed in a long.
private static native long allocate(int dtype, long[] shape, long byteSize);

private static native long allocateScalarBytes(byte[] value);

private static native long allocateNonScalarBytes(long[] shape, Object[] value);

// Elphel extension: allocates the tensor's buffer on the GPU.
private static native long elphelAllocateGPUTensor(long[] shape, int dtype);

// Elphel extension: raw device pointer of the tensor's data.
private static native long elphelGetGPUTensorPointer(long handle);

private static native void delete(long handle);

// Direct ByteBuffer view over the native tensor's data.
private static native ByteBuffer buffer(long handle);

private static native int dtype(long handle);

private static native long[] shape(long handle);

private static native void setValue(long handle, Object value);

// Scalar readers: each throws if the tensor is not a scalar of the matching dtype.
private static native float scalarFloat(long handle);

private static native double scalarDouble(long handle);

private static native int scalarInt(long handle);

private static native long scalarLong(long handle);

private static native boolean scalarBoolean(long handle);

private static native byte[] scalarBytes(long handle);

private static native void readNDArray(long handle, Object value);

//private static native int elphelIsCUDATensor(long handle);
//public static native int elphelTestCUDAPointer();

static {
  // Ensure the native library is loaded before any native method above is touched.
  TensorFlow.init();
}
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
package org.tensorflow;
/** Static utility methods describing the TensorFlow runtime. */
public final class TensorFlow {
  /** Returns the version of the underlying TensorFlow runtime. */
  public static native String version();

  /**
   * Elphel extension: returns a version string from the custom native additions.
   *
   * <p>NOTE(review): exact semantics are defined by the native elphelVersion implementation —
   * confirm there before relying on the format.
   */
  public static native String elphelVersion();

  /**
   * All the TensorFlow operations available in this address space.
   *
   * @return A serialized representation of an <a
   *     href="https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto">OpList</a>
   *     protocol buffer, which lists all the available TensorFlow operations.
   */
  public static native byte[] registeredOpList();

  /**
   * Load the dynamic library in filename and register the operations and kernels present in that
   * library.
   *
   * @param filename Path of the dynamic library containing operations and kernels to load.
   * @return Serialized bytes of the <a
   *     href="https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto">OpList</a>
   *     protocol buffer message defining the operations defined in the library.
   * @throws UnsatisfiedLinkError if filename cannot be loaded.
   */
  public static byte[] loadLibrary(String filename) {
    long h = 0;
    try {
      h = libraryLoad(filename);
    } catch (RuntimeException e) {
      // Fix: keep the original exception as the cause instead of flattening it to a message.
      UnsatisfiedLinkError error = new UnsatisfiedLinkError(e.getMessage());
      error.initCause(e);
      throw error;
    }
    try {
      return libraryOpList(h);
    } finally {
      // Always free the native library handle, even if listing its ops fails.
      libraryDelete(h);
    }
  }

  private static native long libraryLoad(String filename);

  private static native void libraryDelete(long handle);

  private static native byte[] libraryOpList(long handle);

  /** Not instantiable: static utilities only. */
  private TensorFlow() {}

  /** Load the TensorFlow runtime C library. */
  static void init() {
    try {
      NativeLibrary.load();
    } catch (Exception e) {
      /*
       * This code is called during static initialization of this and of other classes.
       * If this fails then a NoClassDefFoundError is thrown however this does not
       * include a cause. Printing the exception manually here ensures that the
       * necessary information to fix the problem is available.
       */
      System.err.println("Failed to load TensorFlow native library");
      e.printStackTrace();
      throw e;
    }
  }

  static {
    init();
  }
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string.h>
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/java/src/main/native/utils_jni.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
#include "tensorflow/java/src/main/native/session_jni.h"
#include "tensorflow/core/common_runtime/direct_session.h"
//#include "tensorflow/cc/client/client_session.h"
#include "cuda_runtime_api.h"
namespace {
// Unpacks the jlong session handle; throws NullPointerException into the JVM and returns
// nullptr if the Java side already close()d the session (handle == 0).
TF_Session* requireHandle(JNIEnv* env, jlong handle) {
  static_assert(sizeof(jlong) >= sizeof(TF_Session*),
                "Cannot package C object pointers as a Java long");
  if (handle == 0) {
    throwException(env, kNullPointerException,
                   "close() has been called on the Session");
    return nullptr;
  }
  return reinterpret_cast<TF_Session*>(handle);
}

// Converts a jlongArray of packed native pointers into an array of typed pointers.
// Throws into the JVM (and stops early) on a length mismatch or a zero handle; callers must
// check env->ExceptionCheck() afterwards.
template <class T>
void resolveHandles(JNIEnv* env, const char* type, jlongArray src_array,
                    T** dst, jint n) {
  if (env->ExceptionCheck()) return;
  jint len = env->GetArrayLength(src_array);
  if (len != n) {
    throwException(env, kIllegalArgumentException, "expected %d, got %d %s", n,
                   len, type);
    return;
  }
  jlong* src_start = env->GetLongArrayElements(src_array, nullptr);
  jlong* src = src_start;
  for (int i = 0; i < n; ++i, ++src, ++dst) {
    if (*src == 0) {
      throwException(env, kNullPointerException, "invalid %s (#%d of %d)", type,
                     i, n);
      break;
    }
    *dst = reinterpret_cast<T*>(*src);
  }
  // JNI_ABORT: the array was only read, no need to copy it back.
  env->ReleaseLongArrayElements(src_array, src_start, JNI_ABORT);
}

// Null-tolerant deleter so unique_tf_buffer can wrap an optional TF_Buffer.
void TF_MaybeDeleteBuffer(TF_Buffer* buf) {
  if (buf == nullptr) return;
  TF_DeleteBuffer(buf);
}

typedef std::unique_ptr<TF_Buffer, decltype(&TF_MaybeDeleteBuffer)>
    unique_tf_buffer;

unique_tf_buffer MakeUniqueBuffer(TF_Buffer* buf) {
  return unique_tf_buffer(buf, TF_MaybeDeleteBuffer);
}
}  // namespace
// Single-argument variant: delegates to allocate2 with default target and config.
JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate(
    JNIEnv* env, jclass clazz, jlong graph_handle) {
  return Java_org_tensorflow_Session_allocate2(env, clazz, graph_handle,
                                               nullptr, nullptr);
}

// Creates a TF_Session for the given graph with an optional target string and serialized
// ConfigProto bytes. Returns the session packed in a jlong, or 0 after throwing into the JVM.
JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate2(
    JNIEnv* env, jclass clazz, jlong graph_handle, jstring target,
    jbyteArray config) {
  if (graph_handle == 0) {
    throwException(env, kNullPointerException, "Graph has been close()d");
    return 0;
  }
  TF_Graph* graph = reinterpret_cast<TF_Graph*>(graph_handle);
  TF_Status* status = TF_NewStatus();
  TF_SessionOptions* opts = TF_NewSessionOptions();
  jbyte* cconfig = nullptr;
  if (config != nullptr) {
    cconfig = env->GetByteArrayElements(config, nullptr);
    TF_SetConfig(opts, cconfig,
                 static_cast<size_t>(env->GetArrayLength(config)), status);
    if (!throwExceptionIfNotOK(env, status)) {
      // Bad config: release everything acquired so far before bailing out.
      env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
      TF_DeleteSessionOptions(opts);
      TF_DeleteStatus(status);
      return 0;
    }
  }
  const char* ctarget = nullptr;
  if (target != nullptr) {
    ctarget = env->GetStringUTFChars(target, nullptr);
  }
  TF_Session* session = TF_NewSession(graph, opts, status);
  // Pinned JNI arrays/strings are only needed during TF_NewSession; release them now.
  if (config != nullptr) {
    env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
  }
  if (target != nullptr) {
    env->ReleaseStringUTFChars(target, ctarget);
  }
  TF_DeleteSessionOptions(opts);
  bool ok = throwExceptionIfNotOK(env, status);
  TF_DeleteStatus(status);
  return ok ? reinterpret_cast<jlong>(session) : 0;
}

// Closes and deletes the native session; errors from close are ignored, deletion proceeds.
JNIEXPORT void JNICALL Java_org_tensorflow_Session_delete(JNIEnv* env,
                                                          jclass clazz,
                                                          jlong handle) {
  TF_Session* session = requireHandle(env, handle);
  if (session == nullptr) return;
  TF_Status* status = TF_NewStatus();
  TF_CloseSession(session, status);
  // Result of close is ignored, delete anyway.
  TF_DeleteSession(session, status);
  throwExceptionIfNotOK(env, status);
  TF_DeleteStatus(status);
}
// Runs the session: feeds input tensors into (op, index) endpoints, fetches outputs, and
// optionally executes target ops. Writes the fetched TF_Tensor handles back into
// output_tensor_handles and returns serialized RunMetadata bytes (or nullptr).
JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
    JNIEnv* env, jclass clazz, jlong handle, jbyteArray jrun_options,
    jlongArray input_tensor_handles, jlongArray input_op_handles,
    jintArray input_op_indices, jlongArray output_op_handles,
    jintArray output_op_indices, jlongArray target_op_handles,
    jboolean want_run_metadata, jlongArray output_tensor_handles) {
  TF_Session* session = requireHandle(env, handle);
  if (session == nullptr) return nullptr;
  const jint ninputs = env->GetArrayLength(input_tensor_handles);
  const jint noutputs = env->GetArrayLength(output_tensor_handles);
  const jint ntargets = env->GetArrayLength(target_op_handles);
  std::unique_ptr<TF_Output[]> inputs(new TF_Output[ninputs]);
  std::unique_ptr<TF_Tensor* []> input_values(new TF_Tensor*[ninputs]);
  std::unique_ptr<TF_Output[]> outputs(new TF_Output[noutputs]);
  std::unique_ptr<TF_Tensor* []> output_values(new TF_Tensor*[noutputs]);
  std::unique_ptr<TF_Operation* []> targets(new TF_Operation*[ntargets]);
  unique_tf_buffer run_metadata(
      MakeUniqueBuffer(want_run_metadata ? TF_NewBuffer() : nullptr));
  // Unpack all Java-side handle arrays; each call is a no-op if a prior one threw.
  resolveHandles(env, "input Tensors", input_tensor_handles, input_values.get(),
                 ninputs);
  resolveOutputs(env, "input", input_op_handles, input_op_indices, inputs.get(),
                 ninputs);
  resolveOutputs(env, "output", output_op_handles, output_op_indices,
                 outputs.get(), noutputs);
  resolveHandles(env, "target Operations", target_op_handles, targets.get(),
                 ntargets);
  if (env->ExceptionCheck()) return nullptr;
  TF_Status* status = TF_NewStatus();
  unique_tf_buffer run_options(MakeUniqueBuffer(nullptr));
  jbyte* jrun_options_data = nullptr;
  if (jrun_options != nullptr) {
    size_t sz = env->GetArrayLength(jrun_options);
    if (sz > 0) {
      // Copy the serialized RunOptions proto into a TF_Buffer for the C API.
      jrun_options_data = env->GetByteArrayElements(jrun_options, nullptr);
      run_options.reset(
          TF_NewBufferFromString(static_cast<void*>(jrun_options_data), sz));
    }
  }
  TF_SessionRun(session, run_options.get(), inputs.get(), input_values.get(),
                static_cast<int>(ninputs), outputs.get(), output_values.get(),
                static_cast<int>(noutputs), targets.get(),
                static_cast<int>(ntargets), run_metadata.get(), status);
  if (jrun_options_data != nullptr) {
    env->ReleaseByteArrayElements(jrun_options, jrun_options_data, JNI_ABORT);
  }
  if (!throwExceptionIfNotOK(env, status)) {
    TF_DeleteStatus(status);
    return nullptr;
  }
  // Pack the fetched TF_Tensor pointers back into the Java long array.
  jlong* t = env->GetLongArrayElements(output_tensor_handles, nullptr);
  for (int i = 0; i < noutputs; ++i) {
    t[i] = reinterpret_cast<jlong>(output_values[i]);
  }
  // Mode 0: copy back to the Java array and free the pinned copy.
  env->ReleaseLongArrayElements(output_tensor_handles, t, 0);
  jbyteArray ret = nullptr;
  if (run_metadata != nullptr) {
    ret = env->NewByteArray(run_metadata->length);
    env->SetByteArrayRegion(ret, 0, run_metadata->length,
                            reinterpret_cast<const jbyte*>(run_metadata->data));
  }
  TF_DeleteStatus(status);
  return ret;
}
// Create an empty tensor of type 'dtype'. 'shape' can be arbitrary, but has to
// result in a zero-sized tensor.
static TF_Tensor* EmptyTensor(TF_DataType dtype,
                              const tensorflow::TensorShape& shape) {
  // One static byte stands in for the (never-read) data pointer of all empty tensors.
  static char empty;
  tensorflow::int64 nelems = 1;
  std::vector<tensorflow::int64> dims;
  for (int i = 0; i < shape.dims(); ++i) {
    dims.push_back(shape.dim_size(i));
    nelems *= shape.dim_size(i);
  }
  // Precondition: the shape must describe zero elements.
  CHECK_EQ(nelems, 0);
  static_assert(sizeof(int64_t) == sizeof(tensorflow::int64),
                "64-bit int types should match in size");
  // No-op deallocator: the static buffer must never be freed.
  return TF_NewTensor(
      dtype, reinterpret_cast<const int64_t*>(dims.data()), shape.dims(),
      reinterpret_cast<void*>(&empty), 0, [](void*, size_t, void*) {}, nullptr);
}
// Elphel extension: runs a previously-made Session::CallableHandle with the given input
// tensors and writes the resulting TF_Tensor handles back into output_tensor_handles.
// Always returns nullptr (the jbyteArray return mirrors Session_run's signature).
JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_elphelRunCallable(
    JNIEnv* env, jclass clazz,
    jlong session_handle, jlong callable_handle,
    jlongArray input_tensor_handles,
    jlongArray output_tensor_handles) {
  TF_Session* session = requireHandle(env, session_handle);
  // Fix: a closed/null session previously fell through and dereferenced nullptr below.
  if (session == nullptr) return nullptr;
  using namespace tensorflow;
  Session::CallableHandle feed_gpu_fetch_cpu =
      (Session::CallableHandle) reinterpret_cast<long>(callable_handle);
  const jint ninputs = env->GetArrayLength(input_tensor_handles);
  const jint noutputs = env->GetArrayLength(output_tensor_handles);
  std::unique_ptr<TF_Tensor* []> output_values(new TF_Tensor*[noutputs]);
  std::unique_ptr<TF_Tensor* []> input_values(new TF_Tensor*[ninputs]);
  // Unpack the Java-side tensor handles into TF_Tensor pointers.
  resolveHandles(env, "input Tensors", input_tensor_handles, input_values.get(), ninputs);
  // Fix: bail out if resolveHandles threw instead of reading unresolved pointers.
  if (env->ExceptionCheck()) return nullptr;
  std::vector<Tensor> inputs(ninputs);
  for (int i = 0; i < ninputs; ++i) {
    TF_TensorToTensor(input_values[i], &inputs[i]);
  }
  std::vector<Tensor> outputs(noutputs);
  auto runStatus = session->session->RunCallable(feed_gpu_fetch_cpu, {inputs}, &outputs, nullptr);
  if (!runStatus.ok()) {
    printf("It is with a heavy heart I inform you that RunCallable has failed. Here's the error message:\n");
    // Fix: never use an arbitrary error string as a printf format string.
    printf("%s", runStatus.error_message().c_str());
    return nullptr;
  }
  jlong* t = env->GetLongArrayElements(output_tensor_handles, nullptr);
  TF_Status* status = TF_NewStatus();
  for (int i = 0; i < noutputs; ++i) {
    const Tensor& src = outputs[i];
    if (!src.IsInitialized() || src.NumElements() == 0) {
      // Zero-element result: synthesize an empty TF_Tensor of the right dtype/shape.
      output_values[i] = EmptyTensor(static_cast<TF_DataType>(src.dtype()), src.shape());
    } else {
      output_values[i] = TF_TensorFromTensor(src, status);
    }
    // Fix: the empty-tensor branch previously `continue`d without writing t[i],
    // leaving a stale handle in the Java array.
    t[i] = reinterpret_cast<jlong>(output_values[i]);
  }
  // Fix: status was previously leaked.
  TF_DeleteStatus(status);
  // Mode 0 copies the updated handle array back so the Java side sees the new tensors.
  env->ReleaseLongArrayElements(output_tensor_handles, t, 0);
  return nullptr;
}
// Elphel extension: returns the fully-qualified name of the first GPU device visible to the
// session (e.g. for building CallableOptions feed/fetch device maps), or "" if none / the
// session is closed.
JNIEXPORT jstring JNICALL Java_org_tensorflow_Session_elphelGetGPUDeviceName(JNIEnv* env,
                                                                             jclass clazz,
                                                                             jlong handle) {
  TF_Session* session = requireHandle(env, handle);
  if (session == nullptr) return env->NewStringUTF("");
  using namespace tensorflow;
  std::vector<DeviceAttributes> devices;
  TF_CHECK_OK(session->session->ListDevices(&devices));
  for (const DeviceAttributes& d : devices) {
    LOG(INFO) << "Device: " << d.name();
    if (d.device_type() == "GPU" || d.device_type() == "gpu") {
      // First match wins; the name is stable for the session's lifetime.
      return env->NewStringUTF(d.name().c_str());
    }
  }
  return env->NewStringUTF("");
}
// Elphel extension: builds a Session::CallableHandle from serialized CallableOptions bytes.
// Returns the handle packed in a jlong, or -1 on failure.
JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_elphelMakeCallable(
    JNIEnv* env, jclass clazz, jlong session_handle, jbyteArray config){
  TF_Session* session = requireHandle(env, session_handle);
  // Fix: a closed/null session previously fell through and dereferenced nullptr below.
  if (session == nullptr) return -1;
  using namespace tensorflow;
  CallableOptions opts;
  jbyte* cconfig = nullptr;
  if (config != nullptr) {
    cconfig = env->GetByteArrayElements(config, nullptr);
    opts.ParseFromArray(cconfig, static_cast<size_t>(env->GetArrayLength(config)));
    // Fix: the pinned array was never released (leak); JNI_ABORT since it was only read.
    env->ReleaseByteArrayElements(config, cconfig, JNI_ABORT);
  }
  Session::CallableHandle feed_gpu_fetch_cpu;
  auto runStatus = session->session->MakeCallable(opts, &feed_gpu_fetch_cpu);
  if (!runStatus.ok()){
    printf("It is with a heavy heart I inform you that MakeCallable has failed. Here's the error message:\n");
    // Fix: never use an arbitrary error string as a printf format string.
    printf("%s", runStatus.error_message().c_str());
    return -1;
  }
  return reinterpret_cast<jlong>((long) feed_gpu_fetch_cpu);
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SESSION_JNI_H_
#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SESSION_JNI_H_
#include <jni.h>
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_tensorflow_Session
* Method: allocate
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate(JNIEnv *, jclass,
jlong);
/*
* Class: org_tensorflow_Session
* Method: allocate2
* Signature: (JLjava/lang/String;[B)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_allocate2(JNIEnv *, jclass,
jlong, jstring,
jbyteArray);
/*
* Class: org_tensorflow_Session
* Method: delete
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_Session_delete(JNIEnv *, jclass,
jlong);
/*
* Class: org_tensorflow_Session
* Method: run
* Signature: (J[B[J[J[I[J[I[JZ[J)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_run(
JNIEnv *, jclass, jlong, jbyteArray, jlongArray, jlongArray, jintArray,
jlongArray, jintArray, jlongArray, jboolean, jlongArray);
/*
 * Class:     org_tensorflow_Session
 * Method:    elphelGetGPUDeviceName
 * Signature: (J)Ljava/lang/String;
 * Custom Elphel extension: name of a GPU device of the session, or "" when
 * none is found.
 */
JNIEXPORT jstring JNICALL Java_org_tensorflow_Session_elphelGetGPUDeviceName(
    JNIEnv*, jclass, jlong handle);
/*
 * Class:     org_tensorflow_Session
 * Method:    elphelMakeCallable
 * Signature: (J[B)J
 * Custom Elphel extension: builds a session callable from a serialized
 * CallableOptions proto; returns the callable handle, or -1 on failure.
 */
JNIEXPORT jlong JNICALL Java_org_tensorflow_Session_elphelMakeCallable(
    JNIEnv*, jclass, jlong, jbyteArray);
/*
 * Class:     org_tensorflow_Session
 * Method:    elphelRunCallable
 * Signature: (JJ[J[J)[B
 * Custom Elphel extension: runs a previously created callable with the given
 * feed/fetch tensor handles.
 */
JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Session_elphelRunCallable(
    JNIEnv*, jclass, jlong, jlong, jlongArray, jlongArray);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_SESSION_JNI_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/java/src/main/native/tensor_jni.h"
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>
#include <memory>
#include "tensorflow/c/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
#include <cuda_runtime_api.h>
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/framework/tensor.h"
// GPU allocator
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h"
using tensorflow::Tensor;
//#include "tensorflow/core/common_runtime/direct_session.h"
namespace {
// Convert a Java-held handle back into a TF_Tensor*. Throws a
// NullPointerException and returns nullptr if the tensor was already closed.
TF_Tensor* requireHandle(JNIEnv* env, jlong handle) {
  if (handle != 0) {
    return reinterpret_cast<TF_Tensor*>(handle);
  }
  throwException(env, kNullPointerException,
                 "close() was called on the Tensor");
  return nullptr;
}
// Byte width of a single element of `dtype`; returns 0 for dtypes this file
// does not support (e.g. TF_STRING, which has no fixed element size).
size_t elemByteSize(TF_DataType dtype) {
  // The code in this file makes the assumption that the
  // TensorFlow TF_DataTypes and the Java primitive types
  // have the same byte sizes. Validate that:
  switch (dtype) {
    case TF_BOOL:
    case TF_UINT8:
      static_assert(sizeof(jboolean) == 1,
                    "Java boolean not compatible with TF_BOOL");
      static_assert(sizeof(jbyte) == 1,
                    "Java byte not compatible with TF_UINT8");
      return 1;
    case TF_FLOAT:
    case TF_INT32:
      static_assert(sizeof(jfloat) == 4,
                    "Java float not compatible with TF_FLOAT");
      static_assert(sizeof(jint) == 4, "Java int not compatible with TF_INT32");
      return 4;
    case TF_DOUBLE:
    case TF_INT64:
      static_assert(sizeof(jdouble) == 8,
                    "Java double not compatible with TF_DOUBLE");
      static_assert(sizeof(jlong) == 8,
                    "Java long not compatible with TF_INT64");
      return 8;
    default:
      return 0;
  }
}
// Write a Java scalar object (java.lang.Integer etc.) to a TF_Tensor.
// `src` must be a boxed java.lang.Number (or Boolean for TF_BOOL); throws
// IllegalStateException if dst_size does not match the dtype's element size.
void writeScalar(JNIEnv* env, jobject src, TF_DataType dtype, void* dst,
                 size_t dst_size) {
  size_t sz = elemByteSize(dtype);
  if (sz != dst_size) {
    throwException(
        env, kIllegalStateException,
        "scalar (%d bytes) not compatible with allocated tensor (%d bytes)", sz,
        dst_size);
    return;
  }
  switch (dtype) {
// env->FindClass and env->GetMethodID are expensive and JNI best practices
// suggest that they should be cached. However, until the creation of scalar
// valued tensors seems to become a noticeable fraction of program execution,
// ignore that cost.
#define CASE(dtype, jtype, method_name, method_signature, call_type)           \
  case dtype: {                                                                \
    jclass clazz = env->FindClass("java/lang/Number");                         \
    jmethodID method = env->GetMethodID(clazz, method_name, method_signature); \
    jtype v = env->Call##call_type##Method(src, method);                       \
    memcpy(dst, &v, sz);                                                       \
    return;                                                                    \
  }
    CASE(TF_FLOAT, jfloat, "floatValue", "()F", Float);
    CASE(TF_DOUBLE, jdouble, "doubleValue", "()D", Double);
    CASE(TF_INT32, jint, "intValue", "()I", Int);
    CASE(TF_INT64, jlong, "longValue", "()J", Long);
    CASE(TF_UINT8, jbyte, "byteValue", "()B", Byte);
#undef CASE
    case TF_BOOL: {
      // Boolean is not a java.lang.Number, so it cannot share the macro above.
      jclass clazz = env->FindClass("java/lang/Boolean");
      jmethodID method = env->GetMethodID(clazz, "booleanValue", "()Z");
      jboolean v = env->CallBooleanMethod(src, method);
      *(static_cast<unsigned char*>(dst)) = v ? 1 : 0;
      return;
    }
    default:
      throwException(env, kIllegalStateException, "invalid DataType(%d)",
                     dtype);
      return;
  }
}
// Copy a 1-D array of Java primitive types to the tensor buffer dst.
// Returns the number of bytes written to dst; throws IllegalStateException
// (and writes nothing, returning 0) if the Java array exceeds dst_size.
size_t write1DArray(JNIEnv* env, jarray array, TF_DataType dtype, void* dst,
                    size_t dst_size) {
  const int nelems = env->GetArrayLength(array);
  jboolean is_copy;
  switch (dtype) {
// Pin the Java array, bounds-check, memcpy, then release with JNI_ABORT
// (read-only access, no copy-back needed).
#define CASE(dtype, jtype, get_type)                                    \
  case dtype: {                                                         \
    jtype##Array a = static_cast<jtype##Array>(array);                  \
    jtype* values = env->Get##get_type##ArrayElements(a, &is_copy);     \
    size_t to_copy = nelems * elemByteSize(dtype);                      \
    if (to_copy > dst_size) {                                           \
      throwException(                                                   \
          env, kIllegalStateException,                                  \
          "cannot write Java array of %d bytes to Tensor of %d bytes",  \
          to_copy, dst_size);                                           \
      to_copy = 0;                                                      \
    } else {                                                            \
      memcpy(dst, values, to_copy);                                     \
    }                                                                   \
    env->Release##get_type##ArrayElements(a, values, JNI_ABORT);        \
    return to_copy;                                                     \
  }
    CASE(TF_FLOAT, jfloat, Float);
    CASE(TF_DOUBLE, jdouble, Double);
    CASE(TF_INT32, jint, Int);
    CASE(TF_INT64, jlong, Long);
    CASE(TF_BOOL, jboolean, Boolean);
    CASE(TF_UINT8, jbyte, Byte);
#undef CASE
    default:
      throwException(env, kIllegalStateException, "invalid DataType(%d)",
                     dtype);
      return 0;
  }
}
// Copy the elements of a 1-D array from the tensor buffer src to a 1-D array of
// Java primitive types. Returns the number of bytes read from src, or 0 (with
// an IllegalStateException pending) if filling dst would over-read src.
size_t read1DArray(JNIEnv* env, TF_DataType dtype, const void* src,
                   size_t src_size, jarray dst) {
  const int len = env->GetArrayLength(dst);
  const size_t sz = len * elemByteSize(dtype);
  if (sz > src_size) {
    throwException(
        env, kIllegalStateException,
        "cannot fill a Java array of %d bytes with a Tensor of %d bytes", sz,
        src_size);
    return 0;
  }
  switch (dtype) {
// SetArrayRegion copies straight from native memory into the Java array.
#define CASE(dtype, jtype, primitive_type)                                 \
  case dtype: {                                                            \
    jtype##Array arr = static_cast<jtype##Array>(dst);                     \
    env->Set##primitive_type##ArrayRegion(arr, 0, len,                     \
                                          static_cast<const jtype*>(src)); \
    return sz;                                                             \
  }
    CASE(TF_FLOAT, jfloat, Float);
    CASE(TF_DOUBLE, jdouble, Double);
    CASE(TF_INT32, jint, Int);
    CASE(TF_INT64, jlong, Long);
    CASE(TF_BOOL, jboolean, Boolean);
    CASE(TF_UINT8, jbyte, Byte);
#undef CASE
    default:
      throwException(env, kIllegalStateException, "invalid DataType(%d)",
                     dtype);
  }
  return 0;
}
// Recursively copy a (possibly nested) Java array into dst, one array
// dimension per recursion level; returns the number of bytes written.
size_t writeNDArray(JNIEnv* env, jarray src, TF_DataType dtype, int dims_left,
                    char* dst, size_t dst_size) {
  if (dims_left == 1) {
    return write1DArray(env, src, dtype, dst, dst_size);
  }
  jobjectArray outer = static_cast<jobjectArray>(src);
  const int n = env->GetArrayLength(outer);
  size_t written = 0;
  for (int i = 0; i < n; ++i) {
    jarray elem = static_cast<jarray>(env->GetObjectArrayElement(outer, i));
    written += writeNDArray(env, elem, dtype, dims_left - 1, dst + written,
                            dst_size - written);
    // Release the local ref eagerly so wide arrays don't exhaust the table.
    env->DeleteLocalRef(elem);
    if (env->ExceptionCheck()) break;
  }
  return written;
}
// Recursively copy tensor bytes at src into a (possibly nested) Java array
// dst, one dimension per recursion level; returns bytes consumed from src.
size_t readNDArray(JNIEnv* env, TF_DataType dtype, const char* src,
                   size_t src_size, int dims_left, jarray dst) {
  if (dims_left == 1) {
    return read1DArray(env, dtype, src, src_size, dst);
  }
  jobjectArray outer = static_cast<jobjectArray>(dst);
  const int n = env->GetArrayLength(outer);
  size_t consumed = 0;
  for (int i = 0; i < n; ++i) {
    jarray row = static_cast<jarray>(env->GetObjectArrayElement(outer, i));
    consumed += readNDArray(env, dtype, src + consumed, src_size - consumed,
                            dims_left - 1, row);
    // Release the local ref eagerly so wide arrays don't exhaust the table.
    env->DeleteLocalRef(row);
    if (env->ExceptionCheck()) break;
  }
  return consumed;
}
// Decode one TF_StringEncode'd string starting at src and return it as a new
// Java byte[]; returns nullptr (with *status or a Java exception set) on
// failure.
jbyteArray TF_StringDecodeTojbyteArray(JNIEnv* env, const char* src,
                                       size_t src_len, TF_Status* status) {
  const char* dst = nullptr;
  size_t dst_len = 0;
  TF_StringDecode(src, src_len, &dst, &dst_len, status);
  if (TF_GetCode(status) != TF_OK) {
    return nullptr;
  }
  jbyteArray ret = env->NewByteArray(static_cast<jsize>(dst_len));
  // BUG FIX: NewByteArray can fail (OutOfMemoryError pending); the original
  // passed the null array on to GetByteArrayElements.
  if (ret == nullptr) {
    return nullptr;
  }
  // SetByteArrayRegion copies directly from native memory — no pin/release
  // round trip as in the original Get/memcpy/Release sequence.
  env->SetByteArrayRegion(ret, 0, static_cast<jsize>(dst_len),
                          reinterpret_cast<const jbyte*>(dst));
  return ret;
}
// Incrementally fills a TF_STRING tensor's buffer. Layout: one 8-byte offset
// per element at the front, followed by the TF_StringEncode'd bytes of each
// element. Add() must be called exactly once per element, in element order.
class StringTensorWriter {
 public:
  StringTensorWriter(TF_Tensor* t, int num_elements)
      : offset_(0),
        poffsets_(static_cast<char*>(TF_TensorData(t))),
        pdata_(poffsets_ + 8 * num_elements),
        plimit_(poffsets_ + TF_TensorByteSize(t)) {}
  // Record the next element's offset and encode its bytes into the data
  // section. No-op if *status is already non-OK.
  void Add(const char* src, size_t len, TF_Status* status) {
    if (TF_GetCode(status) != TF_OK) return;
    if (plimit_ - poffsets_ < sizeof(offset_)) {
      TF_SetStatus(status, TF_OUT_OF_RANGE,
                   "TF_STRING tensor encoding ran out of space for offsets, "
                   "this is likely a bug, please file an issue at "
                   "https://github.com/tensorflow/tensorflow/issues/new");
      return;
    }
    memcpy(poffsets_, &offset_, sizeof(offset_));
    size_t written =
        TF_StringEncode(src, len, pdata_, (plimit_ - pdata_), status);
    offset_ += written;
    poffsets_ += 8;
    pdata_ += written;
  }
 private:
  uint64_t offset_;     // byte offset (within the data section) of next string
  char* poffsets_;      // next free slot in the offset table
  char* pdata_;         // next write position in the data section
  const char* plimit_;  // one past the end of the tensor buffer
};
// Sequential reader for a TF_STRING tensor's buffer (8-byte-per-element
// offset table followed by TF_StringEncode'd data). Next() decodes elements
// one at a time in index order, validating offsets against the buffer bounds.
class StringTensorReader {
 public:
  StringTensorReader(const TF_Tensor* t, int num_elements)
      : index_(0),
        offsets_(static_cast<const char*>(TF_TensorData(t))),
        data_(offsets_ + 8 * num_elements),
        limit_(offsets_ + TF_TensorByteSize(t)) {}
  // Decode the next element into a fresh Java byte[]; returns nullptr with
  // *status set on a malformed tensor or if *status was already non-OK.
  jbyteArray Next(JNIEnv* env, TF_Status* status) {
    if (TF_GetCode(status) != TF_OK) return nullptr;
    uint64_t offset = 0;
    const char* poffset = offsets_ + sizeof(offset) * index_;
    if (poffset >= limit_) {
      TF_SetStatus(
          status, TF_INTERNAL,
          "Invalid TF_STRING tensor, offsets table seems to be too small");
      return nullptr;
    }
    memcpy(&offset, poffset, sizeof(offset));
    const char* pdata = data_ + offset;
    if (pdata >= limit_) {
      TF_SetStatus(status, TF_INTERNAL,
                   "Invalid TF_STRING tensor, invalid entry in offset table");
      return nullptr;
    }
    ++index_;
    return TF_StringDecodeTojbyteArray(env, pdata, (limit_ - pdata), status);
  }
 private:
  int index_;            // index of the next element to decode
  const char* offsets_;  // start of the offset table
  const char* data_;     // start of the encoded string data
  const char* limit_;    // one past the end of the tensor buffer
};
// Fill a (possibly nested) Java array of byte[] with strings pulled
// sequentially from `reader`; stops at the first non-OK status.
void readNDStringArray(JNIEnv* env, StringTensorReader* reader, int dims_left,
                       jobjectArray dst, TF_Status* status) {
  const jsize n = env->GetArrayLength(dst);
  for (jsize i = 0; i < n; ++i) {
    if (dims_left == 1) {
      // Innermost dimension: each slot receives one decoded byte[].
      jbyteArray elem = reader->Next(env, status);
      if (TF_GetCode(status) != TF_OK) return;
      env->SetObjectArrayElement(dst, i, elem);
    } else {
      jobjectArray sub =
          static_cast<jobjectArray>(env->GetObjectArrayElement(dst, i));
      readNDStringArray(env, reader, dims_left - 1, sub, status);
      if (TF_GetCode(status) != TF_OK) return;
    }
  }
}
} // namespace
// Allocates an uninitialized TF_Tensor of the given dtype, shape, and total
// byte size, returning its address as the Java-side handle. Throws
// NullPointerException (and returns 0) if the native allocation fails.
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv* env,
                                                            jclass clazz,
                                                            jint dtype,
                                                            jlongArray shape,
                                                            jlong sizeInBytes) {
  int num_dims = static_cast<int>(env->GetArrayLength(shape));
  jlong* dims = nullptr;
  if (num_dims > 0) {
    jboolean is_copy;
    dims = env->GetLongArrayElements(shape, &is_copy);
  }
  static_assert(sizeof(jlong) == sizeof(int64_t),
                "Java long is not compatible with the TensorFlow C API");
  // On some platforms "jlong" is a "long" while "int64_t" is a "long long".
  //
  // Thus, static_cast<int64_t*>(dims) will trigger a compiler error:
  // static_cast from 'jlong *' (aka 'long *') to 'int64_t *' (aka 'long long
  // *') is not allowed
  //
  // Since this array is typically very small, use the guaranteed safe scheme of
  // creating a copy.
  int64_t* dims_copy = new int64_t[num_dims];
  for (int i = 0; i < num_dims; ++i) {
    dims_copy[i] = static_cast<int64_t>(dims[i]);
  }
  TF_Tensor* t = TF_AllocateTensor(static_cast<TF_DataType>(dtype), dims_copy,
                                   num_dims, static_cast<size_t>(sizeInBytes));
  delete[] dims_copy;
  if (dims != nullptr) {
    // Shape was only read; JNI_ABORT releases without copy-back.
    env->ReleaseLongArrayElements(shape, dims, JNI_ABORT);
  }
  if (t == nullptr) {
    throwException(env, kNullPointerException,
                   "unable to allocate memory for the Tensor");
    return 0;
  }
  return reinterpret_cast<jlong>(t);
}
// Custom Elphel extension: allocates a dense tensor whose buffer is obtained
// from a BFC allocator on platform GPU 0 and returns it wrapped in a
// TF_Tensor handle; 0 on failure (with a Java exception pending).
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_elphelAllocateGPUTensor(JNIEnv* env,
                                                                           jclass clazz,
                                                                           jlongArray shape,
                                                                           jint dtype) {
  using namespace tensorflow;
  DataType dt_dtype = static_cast<DataType>(dtype);
  const int num_dims = static_cast<int>(env->GetArrayLength(shape));
  std::vector<tensorflow::int64> dims(num_dims);
  int64_t num_elements = 1;
  {
    jlong* jdims = env->GetLongArrayElements(shape, nullptr);
    for (int i = 0; i < num_dims; ++i) {
      dims[i] = static_cast<int64>(jdims[i]);
      num_elements *= dims[i];
    }
    // Shape was only read; JNI_ABORT releases without copy-back.
    env->ReleaseLongArrayElements(shape, jdims, JNI_ABORT);
  }
  TensorShape ts_shape = tensorflow::TensorShape(dims);
  tensorflow::PlatformGpuId platform_gpu_id(0);
  tensorflow::GPUMemAllocator *sub_allocator =
      new tensorflow::GPUMemAllocator(
          tensorflow::GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
          platform_gpu_id, false, {}, {});
  // BUG FIX: the pool was sized with sizeof(dt_dtype) — the size of the
  // DataType *enum* value, not of an element of that dtype (wrong for
  // DT_DOUBLE/DT_INT64 etc.). Use DataTypeSize instead.
  // NOTE(review): sub_allocator/allocator are intentionally not freed here —
  // the returned tensor's buffer lives in this allocator. TODO: track and
  // release them when the tensor is deleted.
  tensorflow::GPUBFCAllocator *allocator =
      new tensorflow::GPUBFCAllocator(sub_allocator,
                                      num_elements * DataTypeSize(dt_dtype),
                                      "GPU_0_bfc");
  Tensor t_cuda = Tensor(allocator, dt_dtype, ts_shape);
  TF_Status* status = TF_NewStatus();
  TF_Tensor* t = TF_TensorFromTensor(t_cuda, status);
  // BUG FIX: the original leaked the TF_Status and ignored conversion errors.
  if (!throwExceptionIfNotOK(env, status)) {
    TF_DeleteStatus(status);
    return 0;
  }
  TF_DeleteStatus(status);
  return reinterpret_cast<jlong>(t);
}
// Custom Elphel extension: returns the raw address of the tensor's data
// buffer as a jlong, or -1 (with a NullPointerException pending) for a
// closed handle. Reaches into TF_Tensor internals via c_api_internal.h;
// the address is valid only while the tensor is alive. NOTE(review):
// whether the address is GPU or host memory depends on how the tensor was
// allocated — this function does not check.
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_elphelGetGPUTensorPointer(JNIEnv* env,
                                                                             jclass clazz,
                                                                             jlong handle) {
  TF_Tensor* t = requireHandle(env, handle);
  if (t == nullptr) return -1;
  //using namespace tensorflow;
  return reinterpret_cast<jlong>(t->tensor.tensor_data().data());
}
// Allocates a scalar TF_STRING tensor holding the given bytes and returns
// its handle, or 0 with a Java exception pending on failure.
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateScalarBytes(
    JNIEnv* env, jclass clazz, jbyteArray value) {
  // TF_STRING tensors are encoded with a table of 8-byte offsets followed by
  // TF_StringEncode-encoded bytes.
  size_t src_len = static_cast<int>(env->GetArrayLength(value));
  size_t dst_len = TF_StringEncodedSize(src_len);
  TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, 8 + dst_len);
  // BUG FIX: the original dereferenced `t` without a null check, unlike the
  // sibling allocate() which throws on allocation failure.
  if (t == nullptr) {
    throwException(env, kNullPointerException,
                   "unable to allocate memory for the Tensor");
    return 0;
  }
  char* dst = static_cast<char*>(TF_TensorData(t));
  memset(dst, 0, 8);  // The offset table: a single element at offset 0.
  TF_Status* status = TF_NewStatus();
  jbyte* jsrc = env->GetByteArrayElements(value, nullptr);
  // jsrc is an unsigned byte*, TF_StringEncode requires a char*.
  // reinterpret_cast<> for this conversion should be safe.
  TF_StringEncode(reinterpret_cast<const char*>(jsrc), src_len, dst + 8,
                  dst_len, status);
  env->ReleaseByteArrayElements(value, jsrc, JNI_ABORT);
  jlong ret = 0;
  if (throwExceptionIfNotOK(env, status)) {
    ret = reinterpret_cast<jlong>(t);
  } else {
    // BUG FIX: the original leaked the tensor when encoding failed.
    TF_DeleteTensor(t);
  }
  TF_DeleteStatus(status);
  return ret;
}
namespace {
// Total number of bytes needed to TF_StringEncode every leaf byte[] of the
// nested array `value` (num_dims levels of nesting above the leaves).
// Throws NullPointerException on a null entry and returns the partial sum.
size_t nonScalarTF_STRINGTensorSize(JNIEnv* env, jarray value, int num_dims) {
  if (num_dims == 0) {
    // This is the last dimension, i.e., value should correspond to a jbyteArray
    // encoding the string.
    return TF_StringEncodedSize(
        static_cast<size_t>(env->GetArrayLength(value)));
  }
  jsize len = env->GetArrayLength(value);
  size_t ret = 0;
  for (jsize i = 0; i < len; ++i) {
    jarray elem = static_cast<jarray>(
        env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
    if (elem == nullptr) {
      throwException(env, kNullPointerException,
                     "null entries in provided array");
      return ret;
    }
    ret += nonScalarTF_STRINGTensorSize(env, elem, num_dims - 1);
    // BUG FIX: release the local reference; deep/wide string tensors could
    // otherwise overflow the JNI local reference table.
    env->DeleteLocalRef(elem);
    if (env->ExceptionCheck()) return ret;
  }
  return ret;
}
// Recursively TF_StringEncode every leaf byte[] of the nested array `value`
// into `writer`, stopping early once *status becomes non-OK.
void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims,
                                      StringTensorWriter* writer,
                                      TF_Status* status) {
  if (num_dims == 0) {
    // Leaf: value is the byte[] for one string element.
    jbyte* jsrc =
        env->GetByteArrayElements(static_cast<jbyteArray>(value), nullptr);
    writer->Add(reinterpret_cast<const char*>(jsrc), env->GetArrayLength(value),
                status);
    // JNI_ABORT: the bytes were only read, no copy-back needed.
    env->ReleaseByteArrayElements(static_cast<jbyteArray>(value), jsrc,
                                  JNI_ABORT);
    return;
  }
  jsize len = env->GetArrayLength(value);
  for (jsize i = 0; i < len; ++i) {
    jarray elem = static_cast<jarray>(
        env->GetObjectArrayElement(static_cast<jobjectArray>(value), i));
    fillNonScalarTF_STRINGTensorData(env, elem, num_dims - 1, writer, status);
    // BUG FIX: release the local reference; deep/wide string tensors could
    // otherwise overflow the JNI local reference table.
    env->DeleteLocalRef(elem);
    if (TF_GetCode(status) != TF_OK) return;
  }
}
} // namespace
// Allocates a TF_STRING tensor of the given shape and fills it from a nested
// Java array of byte[] (one byte[] per element). Returns the tensor handle,
// or 0 with a Java exception pending on failure.
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
    JNIEnv* env, jclass clazz, jlongArray shape, jobjectArray value) {
  // TF_STRING tensors are encoded with a table of 8-byte offsets following by
  // TF_StringEncode-encoded bytes.
  const int num_dims = static_cast<int>(env->GetArrayLength(shape));
  int64_t* dims = new int64_t[num_dims];
  int64_t num_elements = 1;
  {
    jlong* jdims = env->GetLongArrayElements(shape, nullptr);
    for (int i = 0; i < num_dims; ++i) {
      dims[i] = static_cast<int64_t>(jdims[i]);
      num_elements *= dims[i];
    }
    // Shape was only read; JNI_ABORT releases without copy-back.
    env->ReleaseLongArrayElements(shape, jdims, JNI_ABORT);
  }
  // First pass over `value` computes the exact encoded size.
  const size_t encoded_size =
      nonScalarTF_STRINGTensorSize(env, value, num_dims);
  if (env->ExceptionCheck()) {
    // BUG FIX: the original leaked `dims` on this early-exit path.
    delete[] dims;
    return 0;
  }
  TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims,
                                   8 * num_elements + encoded_size);
  if (t == nullptr) {
    delete[] dims;
    throwException(env, kNullPointerException,
                   "unable to allocate memory for the Tensor");
    return 0;
  }
  TF_Status* status = TF_NewStatus();
  // Second pass encodes each leaf byte[] into the tensor buffer.
  StringTensorWriter writer(t, num_elements);
  fillNonScalarTF_STRINGTensorData(env, value, num_dims, &writer, status);
  delete[] dims;
  jlong ret = 0;
  if (!throwExceptionIfNotOK(env, status)) {
    TF_DeleteTensor(t);
  } else {
    ret = reinterpret_cast<jlong>(t);
  }
  TF_DeleteStatus(status);
  return ret;
}
// Release the native tensor backing a Java Tensor. A zero (already closed)
// handle is a no-op.
JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv* env,
                                                         jclass clazz,
                                                         jlong handle) {
  TF_Tensor* t = reinterpret_cast<TF_Tensor*>(handle);
  if (t != nullptr) {
    TF_DeleteTensor(t);
  }
}
// Expose the tensor's backing storage to Java as a direct ByteBuffer
// (zero-copy); the buffer is only valid while the tensor is alive.
JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv* env,
                                                            jclass clazz,
                                                            jlong handle) {
  TF_Tensor* t = requireHandle(env, handle);
  if (t == nullptr) return nullptr;
  return env->NewDirectByteBuffer(TF_TensorData(t),
                                  static_cast<jlong>(TF_TensorByteSize(t)));
}
// Report the tensor's TF_DataType to Java as an int (0 for a stale handle,
// with a NullPointerException pending).
JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_dtype(JNIEnv* env,
                                                        jclass clazz,
                                                        jlong handle) {
  static_assert(sizeof(jint) >= sizeof(TF_DataType),
                "TF_DataType in C cannot be represented as an int in Java");
  TF_Tensor* t = requireHandle(env, handle);
  return (t == nullptr) ? 0 : static_cast<jint>(TF_TensorType(t));
}
// Return the tensor's dimensions as a newly allocated Java long[].
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Tensor_shape(JNIEnv* env,
                                                              jclass clazz,
                                                              jlong handle) {
  TF_Tensor* t = requireHandle(env, handle);
  if (t == nullptr) return nullptr;
  static_assert(sizeof(jlong) == sizeof(int64_t),
                "Java long is not compatible with the TensorFlow C API");
  const jsize num_dims = TF_NumDims(t);
  jlongArray ret = env->NewLongArray(num_dims);
  jlong* dims = env->GetLongArrayElements(ret, nullptr);
  for (jsize i = 0; i < num_dims; ++i) {
    dims[i] = static_cast<jlong>(TF_Dim(t, i));
  }
  // Mode 0: copy the values back into the Java array and release the buffer.
  env->ReleaseLongArrayElements(ret, dims, 0);
  return ret;
}
// Copy a Java scalar or (possibly nested) primitive array into the tensor's
// buffer; scalars go through writeScalar, everything else through
// writeNDArray.
JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv* env,
                                                           jclass clazz,
                                                           jlong handle,
                                                           jobject value) {
  TF_Tensor* t = requireHandle(env, handle);
  if (t == nullptr) return;
  const int num_dims = TF_NumDims(t);
  const TF_DataType dtype = TF_TensorType(t);
  void* data = TF_TensorData(t);
  const size_t sz = TF_TensorByteSize(t);
  if (num_dims > 0) {
    writeNDArray(env, static_cast<jarray>(value), dtype, num_dims,
                 static_cast<char*>(data), sz);
    return;
  }
  writeScalar(env, value, dtype, data, sz);
}
// Generates the Java_org_tensorflow_Tensor_scalar<Suffix> accessors. Each
// verifies that the handle is live, the tensor is 0-dimensional, and its
// dtype matches, then copies out the single element; on any error it returns
// 0/false with a Java exception pending.
#define DEFINE_GET_SCALAR_METHOD(jtype, dtype, method_suffix)                  \
  JNIEXPORT jtype JNICALL Java_org_tensorflow_Tensor_scalar##method_suffix(    \
      JNIEnv* env, jclass clazz, jlong handle) {                               \
    jtype ret = 0;                                                             \
    TF_Tensor* t = requireHandle(env, handle);                                 \
    if (t == nullptr) return ret;                                              \
    if (TF_NumDims(t) != 0) {                                                  \
      throwException(env, kIllegalStateException, "Tensor is not a scalar");   \
    } else if (TF_TensorType(t) != dtype) {                                    \
      throwException(env, kIllegalStateException, "Tensor is not a %s scalar", \
                     #method_suffix);                                          \
    } else {                                                                   \
      memcpy(&ret, TF_TensorData(t), elemByteSize(dtype));                     \
    }                                                                          \
    return ret;                                                                \
  }
DEFINE_GET_SCALAR_METHOD(jfloat, TF_FLOAT, Float);
DEFINE_GET_SCALAR_METHOD(jdouble, TF_DOUBLE, Double);
DEFINE_GET_SCALAR_METHOD(jint, TF_INT32, Int);
DEFINE_GET_SCALAR_METHOD(jlong, TF_INT64, Long);
DEFINE_GET_SCALAR_METHOD(jboolean, TF_BOOL, Boolean);
#undef DEFINE_GET_SCALAR_METHOD
// Returns the contents of a scalar TF_STRING tensor as a Java byte[].
// Throws if the tensor is not a live 0-dim TF_STRING tensor or its encoding
// is corrupt.
JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(
    JNIEnv* env, jclass clazz, jlong handle) {
  TF_Tensor* t = requireHandle(env, handle);
  if (t == nullptr) return nullptr;
  if (TF_NumDims(t) != 0) {
    throwException(env, kIllegalStateException, "Tensor is not a scalar");
    return nullptr;
  }
  if (TF_TensorType(t) != TF_STRING) {
    throwException(env, kIllegalArgumentException,
                   "Tensor is not a string/bytes scalar");
    return nullptr;
  }
  // Scalar TF_STRING layout: one 8-byte offset entry, then the encoded bytes.
  const char* data = static_cast<const char*>(TF_TensorData(t));
  const char* src = data + 8;
  size_t src_len = TF_TensorByteSize(t) - 8;
  uint64_t offset = 0;
  memcpy(&offset, data, sizeof(offset));
  if (offset >= src_len) {
    throwException(env, kIllegalArgumentException,
                   "invalid tensor encoding: bad offsets");
    return nullptr;
  }
  TF_Status* status = TF_NewStatus();
  jbyteArray ret = TF_StringDecodeTojbyteArray(env, src, src_len, status);
  throwExceptionIfNotOK(env, status);
  TF_DeleteStatus(status);
  return ret;
}
// Copies the tensor's contents into a pre-shaped (possibly nested) Java
// array `value`. Scalars are rejected (use the scalar accessors); TF_STRING
// tensors take the element-by-element decode path, everything else is a
// recursive memcpy via readNDArray.
JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv* env,
                                                              jclass clazz,
                                                              jlong handle,
                                                              jobject value) {
  TF_Tensor* t = requireHandle(env, handle);
  if (t == nullptr) return;
  int num_dims = TF_NumDims(t);
  TF_DataType dtype = TF_TensorType(t);
  const void* data = TF_TensorData(t);
  const size_t sz = TF_TensorByteSize(t);
  if (num_dims == 0) {
    throwException(env, kIllegalArgumentException,
                   "copyTo() is not meant for scalar Tensors, use the scalar "
                   "accessor (floatValue(), intValue() etc.) instead");
    return;
  }
  if (dtype == TF_STRING) {
    // Element count drives the size of the offset table the reader expects.
    int64_t num_elements = 1;
    for (int i = 0; i < num_dims; ++i) {
      num_elements *= TF_Dim(t, i);
    }
    StringTensorReader reader(t, num_elements);
    TF_Status* status = TF_NewStatus();
    readNDStringArray(env, &reader, num_dims, static_cast<jobjectArray>(value),
                      status);
    throwExceptionIfNotOK(env, status);
    TF_DeleteStatus(status);
    return;
  }
  readNDArray(env, dtype, static_cast<const char*>(data), sz, num_dims,
              static_cast<jarray>(value));
}
/*
JNIEXPORT int JNICALL Java_org_tensorflow_Tensor_elphelIsCUDATensor(JNIEnv* env,
jclass clazz,
jlong handle) {
if (handle == 0) return -1;
TF_Tensor* tf_t = requireHandle(env, handle);
Tensor t;
TF_TensorToTensor(tf_t, &t);
tensorflow::TensorShape shape = tensorflow::TensorShape({256});
tensorflow::PlatformGpuId platform_gpu_id(0);
tensorflow::GPUMemAllocator *sub_allocator =
new tensorflow::GPUMemAllocator(
tensorflow::GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
platform_gpu_id, false, {}, {});
tensorflow::GPUBFCAllocator *allocator =
new tensorflow::GPUBFCAllocator(sub_allocator, shape.num_elements() * sizeof(tensorflow::DT_UINT8), "GPU_0_bfc");
Tensor t_cuda = Tensor(allocator, tensorflow::DT_UINT8, shape);
cudaPointerAttributes attributes;
cudaError_t err = cudaPointerGetAttributes(&attributes, t_cuda.tensor_data().data());
if (err == cudaErrorInvalidValue)
return -2;
#if CUDART_VERSION >= 10000
return (attributes.type == cudaMemoryTypeDevice);
#else
return (attributes.memoryType == cudaMemoryTypeDevice);
#endif
}
*/
/*
JNIEXPORT int JNICALL Java_org_tensorflow_Tensor_elphelTestCUDAPointer(JNIEnv* env,
jclass clazz){
return 0x3;
}
*/
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
#include <jni.h>
#ifdef __cplusplus
extern "C" {
#endif
/*
* Class: org_tensorflow_Tensor
* Method: allocate
* Signature: (I[JJ)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocate(JNIEnv *, jclass,
jint, jlongArray,
jlong);
/*
 * Class:     org_tensorflow_Tensor
 * Method:    elphelAllocateGPUTensor
 * Signature: ([JI)J
 * Custom Elphel extension: allocates a tensor whose buffer is obtained from
 * a GPU allocator; returns the native tensor handle.
 */
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_elphelAllocateGPUTensor(JNIEnv *, jclass,
                                                                           jlongArray, jint);
/*
 * Class:     org_tensorflow_Tensor
 * Method:    elphelGetGPUTensorPointer
 * Signature: (J)J
 * Custom Elphel extension: returns the raw address of the tensor's data
 * buffer, or -1 for a closed handle.
 */
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_elphelGetGPUTensorPointer(JNIEnv*, jclass,
                                                                             jlong);
/*
* Class: org_tensorflow_Tensor
* Method: allocateScalarBytes
* Signature: ([B)J
*/
JNIEXPORT jlong JNICALL
Java_org_tensorflow_Tensor_allocateScalarBytes(JNIEnv *, jclass, jbyteArray);
/*
* Class: org_tensorflow_Tensor
* Method: allocateNonScalarBytes
* Signature: ([J[Ljava/lang/Object;)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes(
JNIEnv *, jclass, jlongArray, jobjectArray);
/*
* Class: org_tensorflow_Tensor
* Method: delete
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_delete(JNIEnv *, jclass,
jlong);
/*
* Class: org_tensorflow_Tensor
* Method: buffer
* Signature: (J)Ljava/nio/ByteBuffer;
*/
JNIEXPORT jobject JNICALL Java_org_tensorflow_Tensor_buffer(JNIEnv *, jclass,
jlong);
/*
* Class: org_tensorflow_Tensor
* Method: dtype
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_dtype(JNIEnv *, jclass,
jlong);
/*
* Class: org_tensorflow_Tensor
* Method: shape
* Signature: (J)[J
*/
JNIEXPORT jlongArray JNICALL Java_org_tensorflow_Tensor_shape(JNIEnv *, jclass,
jlong);
/*
* Class: org_tensorflow_Tensor
* Method: setValue
* Signature: (JLjava/lang/Object;)V
*
* REQUIRES: The jobject's type and shape are compatible the with the DataType
* and shape of the Tensor referred to by the jlong handle.
*/
JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_setValue(JNIEnv *, jclass,
jlong, jobject);
/*
* Class: org_tensorflow_Tensor
* Method: scalarFloat
* Signature: (J)F
*
*/
JNIEXPORT jfloat JNICALL Java_org_tensorflow_Tensor_scalarFloat(JNIEnv *,
jclass, jlong);
/*
* Class: org_tensorflow_Tensor
* Method: scalarDouble
* Signature: (J)D
*/
JNIEXPORT jdouble JNICALL Java_org_tensorflow_Tensor_scalarDouble(JNIEnv *,
jclass,
jlong);
/*
* Class: org_tensorflow_Tensor
* Method: scalarInt
* Signature: (J)I
*/
JNIEXPORT jint JNICALL Java_org_tensorflow_Tensor_scalarInt(JNIEnv *, jclass,
jlong);
/*
* Class: org_tensorflow_Tensor
* Method: scalarLong
* Signature: (J)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_scalarLong(JNIEnv *, jclass,
jlong);
/*
* Class: org_tensorflow_Tensor
* Method: scalarBoolean
* Signature: (J)Z
*/
JNIEXPORT jboolean JNICALL Java_org_tensorflow_Tensor_scalarBoolean(JNIEnv *,
jclass,
jlong);
/*
* Class: org_tensorflow_Tensor
* Method: scalarBytes
* Signature: (J)[B
*/
JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_Tensor_scalarBytes(JNIEnv *,
jclass,
jlong);
/*
* Class: org_tensorflow_Tensor
* Method: readNDArray
* Signature: (JLjava/lang/Object;)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_Tensor_readNDArray(JNIEnv *, jclass,
jlong, jobject);
/*
JNIEXPORT int JNICALL Java_org_tensorflow_Tensor_elphelIsCUDATensor(JNIEnv *,
jclass,
jlong);
*/
/*
JNIEXPORT int JNICALL Java_org_tensorflow_Tensor_elphelTestCUDAPointer(JNIEnv *,
jclass);
*/
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSOR_JNI_H_
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/java/src/main/native/tensorflow_jni.h"
#include <limits>
#include "tensorflow/c/c_api.h"
#include "tensorflow/java/src/main/native/exception_jni.h"
// Report the native TensorFlow runtime's version string to Java.
JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv* env,
                                                                 jclass clazz) {
  const char* version = TF_Version();
  return env->NewStringUTF(version);
}
// Smoke-test hook: returns a fixed marker string proving the custom Elphel
// JNI entry points are linked into the native library.
JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_elphelVersion(JNIEnv* env,
                                                                       jclass clazz) {
  static const char kElphelVersion[] = "Elphel TensorFlow JNI call 1.0";
  return env->NewStringUTF(kElphelVersion);
}
// Serialize the process-wide registered OpList proto into a Java byte[].
JNIEXPORT jbyteArray JNICALL
Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv* env, jclass clazz) {
  TF_Buffer* buf = TF_GetAllOpList();
  // BUG FIX: guard the size_t -> jint narrowing, mirroring the check the
  // sibling libraryOpList already performs.
  if (buf->length > static_cast<size_t>(std::numeric_limits<jint>::max())) {
    TF_DeleteBuffer(buf);
    throwException(env, kIndexOutOfBoundsException,
                   "Serialized OpList is too large for a byte[] array");
    return nullptr;
  }
  jint length = static_cast<jint>(buf->length);
  jbyteArray ret = env->NewByteArray(length);
  env->SetByteArrayRegion(ret, 0, length, static_cast<const jbyte*>(buf->data));
  TF_DeleteBuffer(buf);
  return ret;
}
// Load a TensorFlow op/kernel library from `filename`; throws a Java
// exception on failure, otherwise returns the native TF_Library handle.
JNIEXPORT jlong JNICALL Java_org_tensorflow_TensorFlow_libraryLoad(
    JNIEnv* env, jclass clazz, jstring filename) {
  const char* cname = env->GetStringUTFChars(filename, nullptr);
  TF_Status* status = TF_NewStatus();
  TF_Library* h = TF_LoadLibrary(cname, status);
  throwExceptionIfNotOK(env, status);
  TF_DeleteStatus(status);
  env->ReleaseStringUTFChars(filename, cname);
  return reinterpret_cast<jlong>(h);
}
// Free a loaded TF_Library handle; a zero handle is a no-op.
JNIEXPORT void JNICALL Java_org_tensorflow_TensorFlow_libraryDelete(
    JNIEnv* env, jclass clazz, jlong handle) {
  TF_Library* lib = reinterpret_cast<TF_Library*>(handle);
  if (lib != nullptr) {
    TF_DeleteLibraryHandle(lib);
  }
}
// Serialize the OpList proto of a loaded library into a Java byte[];
// throws if the serialized proto exceeds the byte[] size limit.
JNIEXPORT jbyteArray JNICALL Java_org_tensorflow_TensorFlow_libraryOpList(
    JNIEnv* env, jclass clazz, jlong handle) {
  TF_Buffer buf = TF_GetOpList(reinterpret_cast<TF_Library*>(handle));
  if (buf.length > std::numeric_limits<jint>::max()) {
    throwException(env, kIndexOutOfBoundsException,
                   "Serialized OpList is too large for a byte[] array");
    return nullptr;
  }
  const jint n = static_cast<jint>(buf.length);
  jbyteArray ret = env->NewByteArray(n);
  env->SetByteArrayRegion(ret, 0, n, static_cast<const jbyte*>(buf.data));
  return ret;
}
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_JNI_H_
#define TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_JNI_H_
#include <jni.h>
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
/*
* Class: org_tensorflow_TensorFlow
* Method: version
* Signature: ()Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_version(JNIEnv *,
jclass);
/*
* Class: org_tensorflow_TensorFlow
* Method: version2
* Signature: ()Ljava/lang/String;
*/
JNIEXPORT jstring JNICALL Java_org_tensorflow_TensorFlow_elphelVersion(JNIEnv *,
jclass);
/*
* Class: org_tensorflow_TensorFlow
* Method: registeredOpList
* Signature: ()[B
*/
JNIEXPORT jbyteArray JNICALL
Java_org_tensorflow_TensorFlow_registeredOpList(JNIEnv *, jclass);
/*
* Class: org_tensorflow_TensorFlow
* Method: libraryLoad
* Signature: (Ljava/lang/String;)J
*/
JNIEXPORT jlong JNICALL Java_org_tensorflow_TensorFlow_libraryLoad(JNIEnv *,
jclass,
jstring);
/*
* Class: org_tensorflow_TensorFlow
* Method: libraryDelete
* Signature: (J)V
*/
JNIEXPORT void JNICALL Java_org_tensorflow_TensorFlow_libraryDelete(JNIEnv *,
jclass,
jlong);
/*
* Class: org_tensorflow_TensorFlow
* Method: libraryOpList
* Signature: (J)[B
*/
JNIEXPORT jbyteArray JNICALL
Java_org_tensorflow_TensorFlow_libraryOpList(JNIEnv *, jclass, jlong);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif // TENSORFLOW_JAVA_SRC_MAIN_NATIVE_TENSORFLOW_JNI_H_
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment