快乐学习
前程无忧、中华英才非你莫属!

TensorFlow Java版本入门

简介
TensorFlow™ 是一个使用数据流图进行数值计算的开源软件库。图中的节点代表数学运算, 而图中的边则代表在这些节点之间传递的多维数组(Tensors)。这种灵活的架构可让您使用一个 API 将计算工作部署到桌面设备、服务器或者移动设备中的一个或多个 CPU 或 GPU。 TensorFlow 最初是由 Google 机器智能研究部门的 Google Brain 团队中的研究人员和工程师开发的,用于进行机器学习和深度神经网络研究, 但它是一个非常基础的系统,因此也可以应用于众多其他领域。
tensorflow 是 google 开源的机器学习工具,在2015年11月其实现正式开源,开源协议Apache 2.0

一、构造环境
Java版本1.7  MAVEN3
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>tensorflow</artifactId>
    <version>1.5.0</version>
</dependency>


<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>libtensorflow</artifactId>
    <version>1.5.0</version>
</dependency>
<dependency>
    <groupId>org.tensorflow</groupId>
    <artifactId>libtensorflow_jni_gpu</artifactId>
    <version>1.5.0</version>
</dependency>

HelloWorld
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;

public class HelloTF {
    public static void main(String[] args) throws Exception {
        try (Graph g = new Graph()) {
            final String value = "Hello from " + TensorFlow.version();
            // 用一个单一的操作,一个常量来构造计算图
            // 名为“MyConst”,值为“value”。
            try (Tensor t = Tensor.create(value.getBytes("UTF-8"))) {
                //Java API还不包括添加操作的便利功能。
                g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
            }
            // 在会话中执行“MyConst”操作。
            try (Session s = new Session(g);
                 Tensor output = s.runner().fetch("MyConst").run().get(0)) {
                System.out.println(new String(output.bytesValue(), "UTF-8"));
            }
        }
    }
}


二、整体流程:
  • 图形构建:使用OperationBuilder类构造一个图形来解码,调整大小和规格化JPEG图像。
  • 模型加载:使用Graph.importGraphDef()加载预先训练的Inception模型。
  • 图形执行:使用会话来执行图形并找到图像的最佳标签。
小编个人见解:
大多数网友,都是以Python的一些工具类,来长时间的训练模型,最终生成PB文件,经通过训练好的模型,经过构建、加载、执行、速度也是飞快,只是模型训练,以及利用TensorFlow构造自己的AI机器人,过程还是很复杂,首先你得需要熟练、熟练、熟练,掌握Python!其次还需要掌握点机器学习的基本理论、线性代数以及神经网络的知识  ~
备注:因为TensorFlow的JAVA API,刚出来,demo少啊~,然后java又没有Numpy、pandas、比较热门的数据处理函数库。
就连斯坦福大学出的课程里也是用Python讲的TensorFlow,OMGD,可想而知!
因最近Google开放了Java API,瞬间JAVAEE的应用乐开了花,只要数据够多,模型训练的够好,直接拿过来就用,根本不成问题!瞬间可以秒杀BAT的云服务,如果数据不够多,模型训练的不够好,建议还是使用Bat,因为他们的后台那可是海量数据,后期的模型精准度,一定是很棒的!API也相对简洁!省去了不少企业对AI应用的部署时间!
各位小伙伴的建议,根据每家公司所掌握的数据特质,去选相对应的AI服务
以下是小编对Tensor官网的JAVA的初期翻译!因Python版本的API,国内已经有中文社区,直接看就可以le !官方的TensorFlow有的地区,还得xx!
http://www.tensorfly.cn/tfdoc/api_docs/index.html

三、数据类型
public final枚举 DataType
将a中的元素类型表示Tensor为枚举。

继承的方法

▸从类java.lang.Enum
▸从类java.lang.Object
▸从接口java.lang.Comparable

枚举值

public static final DataType BOOL

public static final DataType DOUBLE

public static final DataType FLOAT

public static final DataType INT32

public static final DataType INT64    

public static final DataType STRING

public static final DataType UINT8


四、核心类和接口
定义要构建,保存,加载和执行TensorFlow模型的类。
警告:API目前是实验性的,并且不包含TensorFlow API稳定性保证。 有关安装说明,请参阅README.md。
LabelImage示例演示如何使用此API来使用预先训练的Inception架构卷积神经网络对图像进行分类。 
它表明:
*图构造:使用OperationBuilder类来构建一个图解码,调整和规格化一个JPEG图像。
*模型加载:使用Graph.importGraphDef()加载预先训练的Inception模型。
*图形执行:使用会话执行图形并找到图像的最佳标签。

接口

Operand<T>
由TensorFlow操作的操作数实现的接口。

Classes

表示TensorFlow计算的数据流图。
在张量上执行计算的图形节点。
图表中的操作的构建器。
Output<T>
由操作产生的张量的象征性句柄。
SavedModelBundle表示从存储装载的模型。
图形执行的驱动程序。
输出张量和执行会话时获得的元数据。
运行操作并评估张量。
通过手术产生的可能部分已知的张量形状。
Tensor<T>
一个静态类型的多维数组,其元素是由T描述的类型
描述TensorFlow运行时的静态实用程序方法。
用于创建张量对象的类型安全工厂方法。

枚举类型

数据类型:将张量中元素的类型表示为枚举。


异常信息
TensorFlowException
执行TensorFlow图形时引发未经检查的异常。

五、Graph
表示TensorFlow计算的数据流图(数据传递和加工角度,以图形方式来表达系统的逻辑功能)。
图的实例是线程安全的
必须通过调用close()方法显式释放Graph对象消耗的资源,然后不再需要Graph对象。

公开的构造器

Graph()
创建一个空的构造方法

公开的方法

void
close()
释放与图形关联的资源。

void
importGraphDef(byte[] graphDef, String prefix)
graphDef:TensorFlow图形的序列化表示
prefix:一个前缀将被添加到graphDef中的名称

导入TensorFlow图形的序列化表示。.
如果graphDef不是一个公认的图形序列化,抛出:IllegalArgumentException

void
importGraphDef(byte[] graphDef)
导入TensorFlow图形的序列化表示.

opBuilder(String type, String name)
type:操作(即,确定要执行的计算)
name:参考图中创建的操作。
返回构建器将操作添加到图形。(当build()调用的时候会把操作添加到图中。如果build()没有被调用,那么一些资源可能会泄漏。)

operation(String name)
用提供的名称返回操作(图中的节点)。
或者null如果Graph中不存在这样的操作。

Iterator<Operation>
迭代器遍历图中的所有Operations。
迭代次序是未指定的。迭代器的消费者将不会收到任何通知,如果在迭代过程中底层图更改

byte[]
生成图形的序列化表示。

继承的方法

▸从类java.lang.Object
▸从接口java.lang.AutoCloseable

六、Operand:操作数*(接口)。
public interface Operand
已知的间接子类:Output<T>

由TensorFlow操作的操作数实现的接口。
用法示例:

// “decodeJpeg”操作可以用作“cast”操作的操作数
Operand<UInt8> decodeJpeg = ops.image().decodeJpeg(…);
ops.math().cast(decodeJpeg, DataType.FLOAT);
// “唯一”操作的输出“y”可以用作“强制”操作的操作数
Output<Integer> y = ops.array().unique(…).y();
ops.math().cast(y, Float.class);
// “split”操作可以用作“concat”操作的操作数列表
Iterable<? extends Operand<Float>> split = ops.array().split(…);
ops.array().concat(0, split);

公开的方法

abstractOutput<T> asOutput()  返回tensor的符号句柄。

public abstract Output<T> asOutput ()

返回tensor的符号句柄。
TensorFlow操作的输入是另一个TensorFlow操作的输出。 此方法用于获取表示输入计算的符号句柄。

See Also

public OperationBuilder addInput (Output<?> input)

返回构建器以创建操作。
TensorFlow操作的输入是另一个TensorFlow操作的输出。 此方法用于将输入添加到OperationBuilder。

参数
input
Output 
应该是OperationBuilder的输入。

返回值
  • 用于链接的OperationBuilder实例。

七、Operation
public final class Operation

在Tensors上执行计算的Graph节点。
一个操作是一个节点,Graph它需要零个或多个Tensors(由Graph中的其他操作生成)作为输入,并产生零个或多个Tensors作为输出。操作实例只有在它们所属的图形有效时才有效。因此,如果 close()已被调用,则Operation实例上的方法可能会失败,并显示一个 IllegalStateException。
操作实例是不可变的,也是线程安全的。

公共方法

boolean
equals(Object o)
INT
hashCode()
INT
inputListLength(String name)
name:tensors 列表的标识符(其中可能有很多)输入到这个操作。
返回此操作的Tensors 的给定输入列表的大小。
一个操作有多个命名输入,每个输入都包含一个tensors或一张tensors。此方法返回操作的特定命名输入的tensors列表的大小。

String
name()
返回操作的全名。

INT
numOutputs()
返回此操作产生的tensors 的数量。

<T>Output<T>
output(int idx)
返回由此操作产生的张量之一的符号句柄。

outputList(int idx,int length)
将符号句柄返回到由此操作产生的张量列表。

INT
outputListLength(String name)
返回此操作生成的张量列表的大小。

String
toString()
String
type()
返回操作的类型,即操作执行的计算的名称。

  

继承的方法

▸从类java.lang.Object

公共方法

public boolean equals (Object o)

参数   o

public int hashCode ()

public int inputListLength (String name)

返回此操作的Tensors的给定输入列表的大小。
一个操作有多个命名输入,每个输入都包含一个tensors或一张tensors。此方法返回操作的特定命名输入的tensors列表的大小。

参数
name
tensors 列表的标识符(其中可能有很多)输入到这个操作。

返回
  • 由这个命名输入生成的tensors 列表的大小。
抛出
抛出:IllegalArgumentException
如果这个操作没有输入提供的名字。

public String name ()

返回操作的全名。

public int numOutputs ()

返回此操作产生的tensors 的数量。

public Output <T> output (int idx)

返回由此操作产生的tensors 之一的符号句柄。
警告:不检查tensors的类型是否与T相匹配。建议使用明确的类型参数调用此方法,而不是推断它,例如 operation.<Integer>output(0)

参数
IDX
此操作产生的输出之间的输出索引。

public Output [] <?> outputList (int idx,int length)

将符号句柄返回到由此操作产生的张量列表。

参数
IDX
列表的第一个张量的索引
length
列表中张量的数量

返回
  • 数组 Output

public int outputListLength (String name)

返回此操作生成的张量列表的大小。
一个操作有多个命名输出,每个输出产生一个张量或一张张量。此方法返回操作的特定命名输出的张量列表的大小。

参数
名称
这个操作产生的张量列表(可能有很多)的标识符。

返回
  • 这个命名输出生成的张量列表的大小。
抛出
抛出:IllegalArgumentException
如果此操作没有提供名称的输出。

public String toString ()

public String type ()

返回操作的类型,即操作执行的计算的名称。


八、OperationBuilder

public final class OperationBuilder
图表中的操作的构建器。
OperationBuilder的实例不是线程安全的。
用于将操作添加到图形的构建器。 例如,以下内容使用构建器创建一个产生常量“3”的操作作为其输出:

// g is a Graph instance.
try (Tensor c1 = Tensor.create(3.0f)) {
   g.opBuilder(“Const”, “MyConst”)
       .setAttr(“dtype”, c1.dataType())
       .setAttr(“value”, c1)
       .build();
}

Public Methods

确保操作在控制操作完成之前不会执行。
addInput(Output<?> input)
返回构建器以创建操作。
addInputList(Output[]<?> inputs)
build()
将正在构建的操作添加到图形。
setAttr(String name, Tensor<?> value)
setAttr(String name, Tensor[]<?> value)
setAttr(String name, String[] value)
setAttr(String name, boolean[] value)
setAttr(String name, DataType[] value)
setAttr(String name, String value)
setAttr(String name, float[] value)
setAttr(String name, long value)
setAttr(String name, long[] value)
setAttr(String name, boolean value)
setAttr(String name, Shape[] value)
setAttr(String name, float value)
setAttr(String name, DataType value)
setAttr(String name, byte[] value)
setAttr(String name, Shape value)
setDevice(String device)

From class java.lang.ObjectInherited Methods

公开的方法

public OperationBuilder addControlInput (Operation control)

确保操作在控制操作完成之前不会执行。
控制输入是在运行当前正在构建的操作之前必须执行的操作。
例如,可以添加一个Assert操作作为此操作的控制输入。 Assert现在作为一个前提条件,在运行操作之前总是会验证自己。

参数
control
运行此操作之前必须执行的操作。

返回值:用于链接的OperationBuilder实例。

public OperationBuilder addInput (Output<?> input)

返回构建器以创建操作。
TensorFlow操作的输入是另一个TensorFlow操作的输出。 此方法用于将输入添加到OperationBuilder。

参数
输入
输出应该是OperationBuilder的输入。
参数
  • 用于链接的OperationBuilder实例。

public OperationBuilder addInputList (Output[]<?> inputs)

参数 :
inputs

public Operation build ()

将正在构建的操作添加到图形

Operation()返回后,OperationBuilder不可用。

上次更新日期:一月 27, 2018

九、Output

public final class Output
由操作产生的tensor 的符号句柄。
输出<T>是一个Tensor <T>的符号句柄。 通过执行会话中的操作来计算张量的值。
通过实现操作数接口,这个类的实例也作为ERROR(Op / org.tensorflow.op.Op Op)实例的操作数。

Public Constructors

Output(Operation op, int idx)
处理Operation op的idx-th输出。
公开的方法
Output<T>
返回tensor的符号句柄。
返回由此输出引用的tensor的数据类型。
boolean
equals(Object o)
int
int
index()
返回操作输出的索引。
op()
返回将生成此输出引用的tensor的操作。
shape()
返回由此输出引用的tensor(可能部分已知)的形状。
String

Inherited Methods

From class java.lang.Object
From interface org.tensorflow.Operand

Public Constructors

public Output (Operation op, int idx)

处理Operation op的idx-th输出。

Parameters
op
idx

Public Methods

public Output<T> asOutput ()

返回tensor的符号句柄。
TensorFlow操作的输入是另一个TensorFlow操作的输出。 此方法用于获取表示输入计算的符号句柄。

public DataType dataType ()

返回由此输出引用的张量的数据类型。

public int index ()

返回操作输出的索引。

public Operation op ()

返回将生成此输出引用的张量的操作。

public Shape shape ()

返回由此输出引用的张量(可能部分已知)的形状。

public String toString ()


十、SavedModelBundle

public class SavedModelBundle
SavedModelBundle表示从存储加载的模型。
该模型由计算的描述(一个Graph),一个具有Tesnor的Session(例如图中的参数或变量)初始化为存储在存储器中的值,以及模型的描述(MetaGraphDef协议缓冲区的序列化表示)。

公开的方法

void
close()
释放与保存的模型包关联的资源(图表和会话)。
graph()
返回描述模型执行的计算的图形。
load(String exportDir, String… tags)
从导出目录加载保存的模型。
byte[]
返回与保存模型关联的序列化MetaGraphDef协议缓冲区。
返回使用模型执行计算的Session。

继承的方法

From class java.lang.Object
From interface java.lang.AutoCloseable

Public Methods

public void close ()

释放与保存的模型包关联的资源(图表和会话)。

public Graph graph ()

返回描述模型执行的计算的图形。

public static SavedModelBundle load (String exportDir, String… tags)

从导出目录加载保存的模型。 正在加载的模型应该使用Saved Model API创建。

Parameters
exportDir
  包含已保存模型的目录路径
tags
标识要加载的特定metagraphdef的标签.
Returns
  • 一个包含图形和相关会话的包。

public byte[] metaGraphDef ()

返回与保存模型关联的序列化MetaGraphDef协议缓冲区。

public Session session ()

返回使用模型执行计算的Session。

Returns
  • 会议开始
上次更新日期:十一月 2, 2017

十一、Session

public final class Session
图形执行的驱动程序。
Session实例封装了执行Graph中的Operations以计算Tensors的环境。 例如:

// Let's say graph is an instance of the Graph class
// for the computation y = 3 * x

try (Session s = new Session(graph)) {
   try (Tensor x = Tensor.create(2.0f);
       Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
       System.out.println(y.floatValue());  // Will print 6.0f
   }
   try (Tensor x = Tensor.create(1.1f);
       Tensor y = s.runner().feed("x", x).fetch("y").run().get(0)) {
       System.out.println(y.floatValue());  // Will print 3.3f
   }
}

警告:会话拥有必须通过调用close()明确释放的资源。
会话的实例是线程安全的。

嵌套类

class 
输出Tensors和执行会话时获得的元数据。
class
Run    操作和评估Tensors

公开构造器

与关联的图构建一个新的会话。
Session(Graph g, byte[] config)
使用关联的图形和配置选项构建新的会话。
公开的方法
void
close()
释放与会话相关的资源。
runner()
创建一个Runner来执行图表操作并评估Tensors。

继承的方法

From class java.lang.Object
From interface java.lang.AutoCloseable

公开的构造方法

public Session (Graph g)

与关联的图构建一个新的会话。

Parameters
g

public Session (Graph g, byte[] config)

使用关联的图形和配置选项构建新的会话。

Parameters
g
创建的会话将运行的图形
config
会话的配置参数指定为序列化的ConfigProto协议缓冲区。
Throws

抛出:IllegalArgumentException

如果配置不是ConfigProto协议缓冲区的有效序列化。


公开的方法

public void close ()

释放与会话相关的资源。
阻塞直到没有活动的执行
(run()调用)。 会话在密切返回后不可用。

public Session.Runner runner ()

创建一个Runner来执行图表操作并评估Tensors。

十二、Session.Run

public static final class Session.Run
输出Tensors和执行会话时获得的元数据。

字段

public byte[] metadata (Experimental): Metadata about the run.
public List<Tensor<?outputs>>  Tensors from requested fetches.

公开的构造器

继承的方法

From class java.lang.Object

Fields

public byte[] metadata (Experimental): Metadata about the run.

一个序列化的RunMetadata协议缓冲区。 org.tensorflow软件包没有任何协议缓冲区依赖关系,以便与资源受限的系统保持友好关系(类似nanoproot的东西可能更合适)。 这个成本是不透明的。 这个选择正在审查中,这个字段可能随时被更多类型安全的等价物取代。

public List<Tensor<?>> outputs

请求提取的Tensors。

十三、Session.Runner

public final class Session.Runner
Run Operations and evaluate Tensors.
Runner运行必要的图片片断来执行评估张量以获取所需的每个操作。 Feed(String,int,Tensor)调用允许调用者通过将提供的Tensors替换为提供给Feed的操作(String,int,Tensor)的输出来覆盖图中Tensors的值。

公开的构造器

公开的方法

addTarget(String operation)
使run()执行操作,但不返回任何评估的Tensors。
addTarget(Operation operation)
使run()执行操作,但不返回任何评估的Tensors。
feed(String operation, Tensor<?> t)
避免评估操作并用t代替它产生的值。
feed(String operation, int index, Tensor<?> t)
避免通过用t代替产生的值来评估操作的索引 – 输出。
feed(Output<?> o, Tensor<?> t)
使用t代替通过执行输出引用的操作所引用的张量。
fetch(String operation)
使run()返回操作的输出。
fetch(String operation, int index)
使run()返回操作的index-th输出。
fetch(Output<?> output)
让run()返回tensors被输出。
List<Tensor<?>>
run()
执行计算所有请求获取所需的图形片段。
执行图片段来计算请求的获取和返回关于运行的元数据。
setOptions(byte[] options)
(Experimental method): 为此运行设置选项(通常用于调试)。

继承的方法

From class java.lang.Object


公开的构造器

public Session.Runner ()

公开的方法

public Session.Runner addTarget (String operation)

使run()执行操作,但不返回任何评估的Tensors。

Parameters
operation

public Session.Runner addTarget (Operation operation)

使run()执行操作,但不返回任何评估的Tensors。

Parameters
operation

public Session.Runner feed (String operation, Tensor<?> t)

避免评估操作并用t代替它产生的值。

Parameters
operation
是操作的字符串名称,在这种情况下,此方法是供稿的简写(操作,0),
或者它是一个形式为operation_name:output_index的字符串,在这种情况下,此方法的行为与feed(operation_name,output_index)相似。 这些以冒号分隔的名称通常用于包含在metaGraphDef()中的SignatureDef协议缓冲区消息。
T

public Session.Runner feed (String operation, int index, Tensor<?> t)

避免通过用t代替产生的值来评估操作的索引 – 输出。
图表中的操作可以有多个输出,索引标识哪一个是提供的。

Parameters
operation
index
t

public Session.Runner feed (Output<?> o, Tensor<?> t)

使用t代替通过执行输出引用的操作所引用的张量。

Parameters
o
t

public Session.Runner fetch (String operation)

Make run() return the output of operation.

Parameters

operation

是操作的字符串名称,在这种情况下,此方法是fetch(操作,0)的简写形式,或者是形式为operation_name:output_index的字符串,在这种情况下,此方法的作用与fetch(operation_name,output_index)。 这些以冒号分隔的名称通常用于包含在metaGraphDef()中的SignatureDef协议缓冲区消息。

使run()返回操作的index-th输出。

 public Session.Runner fetch (String operation, int index)

图中的操作可以有多个输出,索引标识哪一个返回。

Parameters
operation
index

public Session.Runner fetch (Output<?> output)

使run()返回输出引用的张量。

Parameters
output

public List<Tensor<?>> run ()

执行计算所有请求提取所需的图片段。
警告:调用者承担所有返回的张量的所有权,即,调用者必须在返回列表的所有元素上调用close()来释放资源。
TODO(ashankar):在这里重新考虑返回类型。 特别是两件事:(a)让调用者更容易清理(可能返回类似于SessionTest.java中的AutoCloseableList),并且(b)评估返回值是否应该是列表,或者是Map <Output,Tensor>?
TODO(andrewmyers):如果这里返回的内容更容易以类型安全的方式提取输出张量,那也是一件好事。

public Session.Run runAndFetchMetadata ()

执行图片段以计算请求的提取并返回有关运行的元数据。
这与run()完全相同,但除了请求的Tensors之外,还以序列化RunMetadata协议缓冲区的形式返回有关图执行的元数据。

public Session.Runner setOptions (byte[] options)

(实验方法):为此运行设置选项(通常用于调试)。
这些选项以序列化的RunOptions协议缓冲区的形式呈现。
org.tensorflow软件包没有任何协议缓冲区依赖关系,以便与资源受限的系统保持友好关系(类似nanoproot的东西可能更合适)。 这个代价就是这个API函数缺乏类型安全性。 这一选择正在审查之中,并且此功能可能随时被更多的类型安全等价物取代。

Parameters
options

十四、Shape

public final class Shape
手术产生的tensors的可能部分已知的形状。

Public Methods

boolean
equals(Object obj)
int
static Shape
make(long firstDimensionSize, long… otherDimensionSizes)
创建一个表示N维值的Shape。
int
这个形状代表的维数。
static Shape
scalar()
创建一个表示标量值的Shape。
long
size(int i)
第i维的大小。
String
简要描述用于调试的形状。
static Shape
创建一个表示未知数量的维度的形状.

继承的方法

From class java.lang.Object


公开的方法

public boolean equals (Object obj)

Parameters
obj

public int hashCode ()

public static Shape make (long firstDimensionSize, long… otherDimensionSizes)

创建一个表示N维值的Shape。
创建一个表示N维值(N至少为1)的Shape,并为每个维度提供所提供的大小。 -1表示相应维度的大小未知。 例如:

// A 2-element vector.
Shape vector = Shape.create(2);

// A 2x3 matrix.
Shape matrix = Shape.create(2, 3);

// A matrix with 4 columns but an unknown number of rows.
// This is typically used to indicate the shape of tensors that represent
// a variable-sized batch of values. The Shape below might represent a
// variable-sized batch of 4-element vectors.
Shape batch = Shape.create(-1, 4);

Parameters
firstDimensionSize
otherDimensionSizes

public int numDimensions ()

这个形状代表的维数。

Returns
  • 如果维数不确定,则为-1;如果形状表示标量,则为0;向量为1;矩阵等为2。

public static Shape scalar ()

创建一个表示标量值的Shape。

public long size (int i)

第i维的大小。

Parameters
i
Returns
  • 所请求维度的大小,如果未知,则为-1。

public String toString ()

简要描述用于调试的形状。

public static Shape unknown ()

创建一个表示未知数量的维度的形状。

上次更新日期:一月 27, 2018

十五、Tensor

public final class Tensor
一个静态类型的多维数组,其元素是由T描述的类型
张量的实例不是线程安全的。
警告:当不再需要对象时,必须通过调用close()方法显式释放张量对象消耗的资源。 例如,使用try-with-resources块:

try (Tensor t = Tensor.create(…)) {
   doSomethingWith(t);
}

Public Methods

boolean
以标量布尔Tensor返回值。
byte[]
返回标量字符串tensor中的值。
void
close()
释放与Tensor相关的资源。
<U> U
copyTo(U dst)
将张量的内容复制到dst并返回dst。
static Tensor<?>
create(Object obj)
从检查类的对象中创建张量,以确定底层数据类型应该是什么。
static <T> Tensor<T>
create(Class<T> type, long[] shape, ByteBuffer data)
用给定缓冲区的数据创建任何类型的张量。
static Tensor<Double>
create(long[] shape, DoubleBuffer data)
用给定缓冲区的数据创建一个双Tensor。
static Tensor<Long>
create(long[] shape, LongBuffer data)
用给定缓冲区的数据创建一个长Tensor。
static Tensor<Integer>
create(long[] shape, IntBuffer data)
用给定缓冲区的数据创建一个整数tensor。
static Tensor<Float>
create(long[] shape, FloatBuffer data)
用给定缓冲区的数据创建一个浮点Tensor。
static <T> Tensor<T>
create(Object obj, Class<T> type)
从Java对象创建一个Tensor。
返回存储在Tensor中的元素的DataType。
double
以标量Double Tensor返回值。
<U> Tensor<U>
expect(Class<U> type)
使用类型Tensor <U>返回此张量对象。
float
返回标量Float张量中的值。
int
返回标量整数张量中的值。
long
返回标量Long张量的值。
int
返回张量数据的大小(以字节为单位)。
int
返回张量的维数(有时称为等级)。
int
返回张量的展开(1-D)视图中的元素数量。
long[]
shape()
返回张量的形状,即每个尺寸的大小。
String
返回描述张量的类型和形状的字符串。
void
writeTo(LongBuffer dst)
将Long张量的数据写入给定的缓冲区。
void
writeTo(DoubleBuffer dst)
将Double张量的数据写入给定的缓冲区。
void
writeTo(IntBuffer dst)
将整数张量的数据写入给定的缓冲区。
void
writeTo(ByteBuffer dst)
将张量数据写入给定的缓冲区。
void
writeTo(FloatBuffer dst)
将浮点张量的数据写入给定的缓冲区。


public U copyTo (U dst)

将张量的内容复制到dst并返回dst。
对于非标量张量,此方法将底层张量的内容复制到Java数组中。 对于标量张量,请改为使用bytesValue(),floatValue(),doubleValue(),intValue(),longValue()或booleanValue()中的一个。 dst的类型和形状必须与张量兼容。 例如:

int matrix[2][2] = { {1,2},{3,4} };
try(Tensor t = Tensor.create(matrix)) {
// Succeeds and prints “3”
int[][] copy = new int[2][2];
System.out.println(t.copyTo(copy)[1][0]);
// Throws IllegalArgumentException since the shape of dst does not match the shape of t.
int[][] dst = new int[4][1];
t.copyTo(dst);
}


public static Tensor<?> create (Object obj)
从检查类的对象中创建张量,以确定底层数据类型应该是什么。

Parameters
obj
Throws
IllegalArgumentException
:与系统不兼容

public static Tensor<T> create (Class<T> type, long[] shape, ByteBuffer data)

用给定缓冲区的数据创建任何类型的张量。
根据TensorFlow C API的规范将张量数据编码为数据,并创建一个张量,其中包含任何类型的形状。

Parameters
类型
张量元素类型,表示为类对象。
形状
张量形状。
数据
包含张量数据的缓冲区。
抛出:
抛出:IllegalArgumentException
如果张量数据类型或形状与缓冲区不兼容

public static Tensor<Double> create (long[] shape, DoubleBuffer data)

用给定缓冲区的数据创建一个双张量。
通过将缓冲区中的元素(从当前位置开始)复制到张量中,创建具有给定形状的张量。 例如,如果shape = {2,3}(它表示一个2×3矩阵),那么缓冲区必须剩下6个元素,这个方法将会消耗这个元素。

Parameters
shape
 the tensor shape.
data
 a buffer containing the tensor data.
Throws

抛出:IllegalArgumentException

如果张量形状与缓冲区不兼容

public static Tensor<Long> create (long[] shape, LongBuffer data)

用给定缓冲区的数据创建一个长张量。
通过将缓冲区中的元素(从当前位置开始)复制到张量中,创建具有给定形状的张量。 例如,如果shape = {2,3}(它表示一个2×3矩阵),那么缓冲区必须剩下6个元素,这个方法将会消耗这个元素。

Parameters
shape
the tensor shape.
data
a buffer containing the tensor data.
Throws

抛出:IllegalArgumentException

如果张量形状与缓冲区不兼容


public static Tensor<Integer> create (long[] shape, IntBuffer data)

用给定缓冲区的数据创建一个整数张量。
通过将缓冲区中的元素(从当前位置开始)复制到张量中,创建具有给定形状的张量。 例如,如果shape = {2,3}(它表示一个2×3矩阵),那么缓冲区必须剩下6个元素,这个方法将会消耗这个元素。

Parameters
shape
the tensor shape.
data
a buffer containing the tensor data.
Throws
IllegalArgumentException
If the tensor shape is not compatible with the buffer

public static Tensor<Float> create (long[] shape, FloatBuffer data)

用给定缓冲区的数据创建一个浮点张量。
通过将缓冲区中的元素(从当前位置开始)复制到张量中,创建具有给定形状的张量。 例如,如果shape = {2,3}(它表示一个2×3矩阵),那么缓冲区必须剩下6个元素,这个方法将会消耗这个元素。

Parameters
shape
the tensor shape.
data
a buffer containing the tensor data.
Throws
IllegalArgumentException
If the tensor shape is not compatible with the buffer

public static Tensor<T> create (Object obj, Class<T> type)

从Java对象创建一个张量。
张量是一组有限类型元素的多维数组。 并非所有Java对象都可以转换为张量。 特别是,参数obj必须是原始的(float,double,int,long,boolean,byte)或者其中一个基元的多维数组。 参数类型指定如何将第一个参数解释为TensorFlow类型。 例如:

// Valid: A 64-bit integer scalar.
Tensor<Long> s = Tensor.create(42L, Long.class);

// Valid: A 3x2 matrix of floats.
float[][] matrix = new float[3][2];
Tensor<Float> m = Tensor.create(matrix, Float.class);

// Invalid: Will throw an IllegalArgumentException as an arbitrary Object
// does not fit into the TensorFlow type system.
Tensor<?> o = Tensor.create(new Object())

// Invalid: Will throw an IllegalArgumentException since there are
// a differing number of elements in each row of this 2-D array.
int[][] twoD = new int[2][];
twoD[0] = new int[1];
twoD[1] = new int[2];
Tensor<Integer> x = Tensor.create(twoD, Integer.class);
String-typed Tensors are multi-dimensional arrays of arbitrary byte sequences, so can be initialized from arrays of byte[] elements. For example:
// Valid: A String tensor.
Tensor<String> s = Tensor.create(new byte[]{1, 2, 3}, String.class);

// Java Strings will need to be encoded into a byte-sequence.
String mystring = "foo";
Tensor<String> s = Tensor.create(mystring.getBytes("UTF-8"), String.class);

// Valid: Matrix of String tensors.
// Each element might have a different length.
byte[][][] matrix = new byte[2][2][];
matrix[0][0] = "this".getBytes("UTF-8");
matrix[0][1] = "is".getBytes("UTF-8");
matrix[1][0] = "a".getBytes("UTF-8");
matrix[1][1] = "matrix".getBytes("UTF-8");
Tensor<String> m = Tensor.create(matrix, String.class);

Parameters
obj
要转换为张量<T>的对象。 请注意,它是否与类型T兼容不被类型系统检查。 为了类型安全地创建张量,请使用张量。
type
表示类型T的类对象
Throws
IllegalArgumentException
if obj is not compatible with the TensorFlow type system.

public DataType dataType ()

返回存储在Tensor中的元素的DataType。

public double doubleValue ()

以标量Double张量返回值。

Throws
IllegalArgumentException
if the Tensor does not represent a double scalar.

public Tensor<U> expect (Class<U> type)

使用类型Tensor <U>返回此张量对象。 当给定一个Tensor <?>类型的值时,这个方法很有用。

Parameters

type
任何(非null)正确类型的数组。
Throws
IllegalArgumentException
if the actual data type of this object does not match the type U.

public float floatValue ()

返回标量Float张量中的值。

Throws
IllegalArgumentException
if the Tensor does not represent a float scalar.

public int intValue ()

返回标量整数张量中的值。

Throws
IllegalArgumentException
if the Tensor does not represent a int scalar.

public long longValue ()

返回标量Long张量的值。

Throws
IllegalArgumentException
if the Tensor does not represent a long scalar.

public int numBytes ()

返回张量数据的大小(以字节为单位)。

public int numDimensions ()

返回张量的维数(有时称为等级)。
标量为0,矢量为1,矩阵为2,三维张量为3等。

public int numElements ()

返回张量的展开(1-D)视图中的元素数量。

public long[] shape ()

返回张量的形状,即每个尺寸的大小。

Returns
  • 第i个元素是张量的第i维的大小的数组。

public String toString ()

返回描述张量的类型和形状的字符串。

public void writeTo (LongBuffer dst)

将Long张量的数据写入给定的缓冲区。
将numElements()元素复制到缓冲区。

Parameters
DST
目标缓冲区Throws

BufferOverflowException

如果给定缓冲区中的空间不足以用于此张量中的数据
抛出:IllegalArgumentException

如果tensor数据类型不是Long

将Double张量的数据写入给定的buffer.public void writeTo(DoubleBuffer dst)

将numElements()元素复制到缓冲区。

Parameters
DST
目标缓冲区

Throws


BufferOverflowException

如果给定缓冲区中的空间不足以用于此张量中的数据
抛出:IllegalArgumentException

如果张量数据类型不是Double

将整数张量的数据写入给定的buffer.public void writeTo(IntBuffer dst)

将numElements()元素复制到缓冲区。

Parameters
DST
目标缓冲区

Throws


BufferOverflowException
如果给定缓冲区中的空间不足以用于此张量中的数据
抛出:IllegalArgumentException


如果张量数据类型不是整数

将张量数据写入给定的buffer.public void writeTo(ByteBuffer dst)

对于基元类型,以本地字节顺序将numBytes()字节复制到缓冲区。

Parameters
dst
  目标缓冲区
Throws

BufferOverflowException

如果给定缓冲区中的空间不足以用于此张量中的数据


public void writeTo (FloatBuffer dst)

将浮点张量的数据写入给定的缓冲区。
将numElements()元素复制到缓冲区。

Parameters
dst
目标缓冲区
Throws
BufferOverflowException
如果给定缓冲区中的空间不足以用于此张量中的数据
抛出:IllegalArgumentException
如果张量数据类型不是Float

十六、TensorFlow

public final class TensorFlow
描述TensorFlow运行时的静态实用程序方法。

Public Methods

static byte[]
loadLibrary(String filename)
将动态库加载到文件名中,并注册该库中存在的操作和内核。
static byte[]
所有的TensorFlow操作都可以在这个地址空间中使用。
static String 
返回底层TensorFlow运行时的版本。

Inherited Methods

From class java.lang.Object

Public Methods

public static byte[] loadLibrary (String filename)

将动态库加载到文件名中,并注册该库中存在的操作和内核。

Parameters
文件名
包含要加载的操作和内核的动态库的路径。

Returns
OpList协议缓冲区消息的序列化字节定义库中定义的操作。
Throws
UnsatisfiedLinkError
if filename cannot be loaded.

public static byte[] registeredOpList ()

所有的TensorFlow操作都可以在这个地址空间中使用。

Returns
  • OpList协议缓冲区的序列化表示,其中列出了所有可用的TensorFlow操作。

public static String version ()

返回底层TensorFlow运行时的版本。

TensorFlowException

public final class TensorFlowException
Unchecked exception thrown when executing TensorFlow Graphs.

Inherited Methods

From class java.lang.Throwable
From class java.lang.Object

十七、Tensors

public final class Tensors
用于创建张量对象的类型安全工厂方法。

Public Methods

static Tensor<Float>
create(float[][][] data)
创建浮点元素的三级张量。
static Tensor<Double>
create(double[] data)
创建双元素的1级张量。
static Tensor<Boolean>
create(boolean[][][][][] data)
创建布尔元素的秩-5张量。
static Tensor<String>
create(byte[][] data)
创建一个字节元素的秩1张量。
static Tensor<Long>
create(long[] data)
创建长元素的一级张量。
static Tensor<Double>
create(double data)
创建包含单个双元素的标量张量。
static Tensor<Integer>
create(int[][][][][] data)
创建一个int元素的秩-5张量。
static Tensor<Integer>
create(int[][][][][][] data)
创建一个int元素的rank-6张量。
static Tensor<Boolean>
create(boolean[][] data)
创建布尔元素的秩2张量。
static Tensor<Float>
create(float[][][][] data)
创建浮点元素的等级4张量。
static Tensor<Double>
create(double[][] data)
创建双元素的2级张量。
static Tensor<String>
create(byte[][][] data)
创建一个字节元素的秩2张量。
static Tensor<String>
create(byte[][][][][] data)
创建一个字节元素的等级4张量。
static Tensor<Float>
create(float[][][][][] data)
创建浮点元素的5级张量。
static Tensor<Integer>
create(int data)
创建包含单个int元素的标量张量。
static Tensor<Long>
create(long[][][][] data)
创建长元素的等级4张量。
static Tensor<Boolean>
create(boolean data)
创建一个包含单个布尔元素的标量张量。
static Tensor<Double>
create(double[][][] data)
创建双元素的三级张量。
static Tensor<Float>
create(float[][][][][][] data)
创建浮点元素的秩6张量。
static Tensor<String>
create(byte[] data)
创建包含单个字节元素的标量张量。
static Tensor<Integer>
create(int[][] data)
创建int元素的秩2张量。
static Tensor<Integer>
create(int[][][] data)
创建一个int元素的三级张量。
static Tensor<Boolean>
create(boolean[][][] data)
创建布尔元素的等级3张量。
static Tensor<Double>
create(double[][][][][] data)
创建双元素的秩-5张量。
static Tensor<Float>
create(float data)
创建包含单个浮点元素的标量张量。
static Tensor<Long>
create(long[][][] data)
创建长元素的三级张量。
static Tensor<Boolean>
create(boolean[][][][] data)
创建布尔元素的等级4张量。
static Tensor<Float>
create(float[][] data)
创建浮点元素的二级张量。
static Tensor<String>
create(byte[][][][] data)
创建一个字节元素的三级张量。
static Tensor<Long>
create(long[][][][][][] data)
创建长元素的6级张量。
static Tensor<Long>
create(long[][] data)
创建长元素的二级张量。
static Tensor<Boolean>
create(boolean[] data)
创建一个布尔元素的秩-1张量。
static Tensor<Float>
create(float[] data)
创建浮点元素的一级张量。
static Tensor<Long>
create(long[][][][][] data)
创建长元素的5级张量。
static Tensor<String>
create(String data)
使用默认的UTF-8编码创建标量字符串张量。
static Tensor<Double>
create(double[][][][] data)
创建双元素的4级张量。
static Tensor<Boolean>
create(boolean[][][][][][] data)
创建布尔元素的rank-6张量。
static Tensor<Integer>
create(int[][][][] data)
创建int元素的等级4张量。
static Tensor<Long>
create(long data)
创建包含单个长元素的标量张量。
static Tensor<String>
create(String data, Charset charset)
使用指定的编码创建标量字符串张量。
static Tensor<Double>
create(double[][][][][][] data)
创建一个双元素的6级张量。
static Tensor<Integer>
create(int[] data)
创建int元素的1级张量。
static Tensor<String>
create(byte[][][][][][] data)
创建一个字节元素的秩5张量。

继承的方法

从类java.lang.Object
公共方法
public static Tensor <Float> create(float [] [] [] data)
创建浮点元素的三级张量。


十八、图像分类与识别案例

上代码
import org.tensorflow.*;
import org.tensorflow.types.UInt8;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.List;
/** Sample use of the TensorFlow Java API to label images using a pre-trained model. */
public class LabelImage {
    private static void printUsage(PrintStream s) {
        final String url =
                "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";;;
        s.println(
                "Java program that uses a pre-trained Inception model (http://arxiv.org/abs/1512.00567)";;);
        s.println("to label JPEG images.");
        s.println("TensorFlow version: " + TensorFlow.version());
        s.println();
        s.println("Usage: label_image <model dir> <image file>");
        s.println();
        s.println("Where:");
        s.println("<model dir> is a directory containing the unzipped contents of the inception model");
        s.println("            (from " + url + ")");
        s.println("<image file> is the path to a JPEG image file");
    }
    public static void main(String[] args) {
//        if (args.length != 2) {
//            printUsage(System.err);
//            System.exit(1);
//        }
//        String modelDir = args[0];
//        String imageFile = args[1];
        String modelDir = "D:\TestJar\AutoTest\TensorFlow\src\main\resources\model";
        String imageFile ="D:\TestJar\AutoTest\TensorFlow\src\main\resources\images\example-400x288.jpg";
        byte[] graphDef = readAllBytesOrExit(Paths.get(modelDir, "tensorflow_inception_graph.pb"));
        List<String> labels =
                readAllLinesOrExit(Paths.get(modelDir, "imagenet_comp_graph_label_strings.txt"));
        byte[] imageBytes = readAllBytesOrExit(Paths.get(imageFile));
        try (Tensor<Float> image = constructAndExecuteGraphToNormalizeImage(imageBytes)) {
            float[] labelProbabilities = executeInceptionGraph(graphDef, image);
            int bestLabelIdx = maxIndex(labelProbabilities);
            System.out.println(
                    String.format("BEST MATCH: %s (%.2f%% likely)",
                            labels.get(bestLabelIdx),
                            labelProbabilities[bestLabelIdx] * 100f));
        }
    }
    private static Tensor<Float> constructAndExecuteGraphToNormalizeImage(byte[] imageBytes) {
        try (Graph g = new Graph()) {
            GraphBuilder b = new GraphBuilder(g);
            // Some constants specific to the pre-trained model at:
            // https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
            //
            // - The model was trained with images scaled to 224x224 pixels.
            // - The colors, represented as R, G, B in 1-byte each were converted to
            //   float using (value - Mean)/Scale.
            final int H = 224;
            final int W = 224;
            final float mean = 117f;
            final float scale = 1f;
            // Since the graph is being constructed once per execution here, we can use a constant for the
            // input image. If the graph were to be re-used for multiple input images, a placeholder would
            // have been more appropriate.
            final Output<String> input = b.constant("input", imageBytes);
            final Output<Float> output =
                    b.div(
                            b.sub(
                                    b.resizeBilinear(
                                            b.expandDims(
                                                    b.cast(b.decodeJpeg(input, 3), Float.class),
                                                    b.constant("make_batch", 0)),
                                            b.constant("size", new int[] {H, W})),
                                    b.constant("mean", mean)),
                            b.constant("scale", scale));
            try (Session s = new Session(g)) {
                return s.runner().fetch(output.op().name()).run().get(0).expect(Float.class);
            }
        }
    }
    private static float[] executeInceptionGraph(byte[] graphDef, Tensor<Float> image) {
        try (Graph g = new Graph()) {
            g.importGraphDef(graphDef);
            try (Session s = new Session(g);
                 Tensor<Float> result =
                         s.runner().feed("input", image).fetch("output").run().get(0).expect(Float.class)) {
                final long[] rshape = result.shape();
                if (result.numDimensions() != 2 || rshape[0] != 1) {
                    throw new RuntimeException(
                            String.format(
                                    "Expected model to produce a [1 N] shaped tensor where N is the number of labels, instead it produced one with shape %s",
                                    Arrays.toString(rshape)));
                }
                int nlabels = (int) rshape[1];
                return result.copyTo(new float[1][nlabels])[0];
            }
        }
    }
    private static int maxIndex(float[] probabilities) {
        int best = 0;
        for (int i = 1; i < probabilities.length; ++i) {
            if (probabilities[i] > probabilities[best]) {
                best = i;
            }
        }
        return best;
    }
    private static byte[] readAllBytesOrExit(Path path) {
        try {
            return Files.readAllBytes(path);
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(1);
        }
        return null;
    }
    private static List<String> readAllLinesOrExit(Path path) {
        try {
            return Files.readAllLines(path, Charset.forName("UTF-8"));
        } catch (IOException e) {
            System.err.println("Failed to read [" + path + "]: " + e.getMessage());
            System.exit(0);
        }
        return null;
    }
    // In the fullness of time, equivalents of the methods of this class should be auto-generated from
    // the OpDefs linked into libtensorflow_jni.so. That would match what is done in other languages
    // like Python, C++ and Go.
    static class GraphBuilder {
        GraphBuilder(Graph g) {
            this.g = g;
        }
        Output<Float> div(Output<Float> x, Output<Float> y) {
            return binaryOp("Div", x, y);
        }
        <T> Output<T> sub(Output<T> x, Output<T> y) {
            return binaryOp("Sub", x, y);
        }
        <T> Output<Float> resizeBilinear(Output<T> images, Output<Integer> size) {
            return binaryOp3("ResizeBilinear", images, size);
        }
        <T> Output<T> expandDims(Output<T> input, Output<Integer> dim) {
            return binaryOp3("ExpandDims", input, dim);
        }
        <T, U> Output<U> cast(Output<T> value, Class<U> type) {
            DataType dtype = DataType.fromClass(type);
            return g.opBuilder("Cast", "Cast")
                    .addInput(value)
                    .setAttr("DstT", dtype)
                    .build()
                    .<U>output(0);
        }
        Output<UInt8> decodeJpeg(Output<String> contents, long channels) {
            return g.opBuilder("DecodeJpeg", "DecodeJpeg")
                    .addInput(contents)
                    .setAttr("channels", channels)
                    .build()
                    .<UInt8>output(0);
        }
        <T> Output<T> constant(String name, Object value, Class<T> type) {
            try (Tensor<T> t = Tensor.<T>create(value, type)) {
                return g.opBuilder("Const", name)
                        .setAttr("dtype", DataType.fromClass(type))
                        .setAttr("value", t)
                        .build()
                        .<T>output(0);
            }
        }
        Output<String> constant(String name, byte[] value) {
            return this.constant(name, value, String.class);
        }
        Output<Integer> constant(String name, int value) {
            return this.constant(name, value, Integer.class);
        }
        Output<Integer> constant(String name, int[] value) {
            return this.constant(name, value, Integer.class);
        }
        Output<Float> constant(String name, float value) {
            return this.constant(name, value, Float.class);
        }
        private <T> Output<T> binaryOp(String type, Output<T> in1, Output<T> in2) {
            return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
        }
        private <T, U, V> Output<T> binaryOp3(String type, Output<U> in1, Output<V> in2) {
            return g.opBuilder(type, type).addInput(in1).addInput(in2).build().<T>output(0);
        }

     private Graph g;
    }
}
   

打赏

未经允许不得转载:同乐学堂 » TensorFlow Java版本入门

分享到:更多 ()

评论 抢沙发

  • 昵称 (必填)
  • 邮箱 (必填)
  • 网址

特别的技术,给特别的你!

联系QQ:1071235258QQ群:226134712
error: Sorry,暂时内容不可复制!