繁体   English   中英

如何将输入数据传递给 Java 中现有的 tensorflow 2.x 模型?

[英]How to pass input data to an existing tensorflow 2.x model in Java?

我正在用tensorflow做我的tensorflow 在 Python 中为 MNIST 数据创建了一个简单的模型后,我现在想将此模型导入 Java 并将其用于分类。 但是,我无法将输入数据传递给模型。

这是用于模型创建的 Python 代码:

 from tensorflow.keras.datasets import mnist
 from tensorflow.keras.utils import to_categorical.

 (train_images, train_labels), (test_images, test_labels) = mnist.load_data()

 train_images = train_images.reshape((60000, 28, 28, 1))
 train_images = train_images.astype('float32')
 train_images /= 255

 test_images = test_images.reshape((10000, 28, 28, 1))
 test_images = test_images.astype('float32')
 test_images /= 255

 train_labels = to_categorical(train_labels)
 test_labels = to_categorical(test_labels)

 NrTrainimages = train_images.shape[0]
 NrTestimages = test_images.shape[0]

 import os
 import numpy as np

 from tensorflow.keras.callbacks import TensorBoard
 from tensorflow.keras.models import Sequential
 from tensorflow.keras.layers import Dense, Dropout, Flatten
 from tensorflow.keras.layers import Conv2D, MaxPooling2D
 from tensorflow.keras import backend as K

 # Network architecture
 model = Sequential()
 mnist_inputshape = train_images.shape[1:4]

 # Convolutional block 1
 model.add(Conv2D(32, kernel_size=(5,5), 
       activation = 'relu',
       input_shape=mnist_inputshape,
       name = 'Input_Layer'))
 model.add(MaxPooling2D(pool_size=(2,2)))
 # Convolutional block 2
 model.add(Conv2D(64, kernel_size=(5,5),activation= 'relu'))
 model.add(MaxPooling2D(pool_size=(2,2)))
 model.add(Dropout(0.5))

 # Prediction block
 model.add(Flatten())
 model.add(Dense(128, activation='relu', name='features'))
 model.add(Dropout(0.5))
 model.add(Dense(64, activation='relu'))
 model.add(Dense(10, activation='softmax', name = 'Output_Layer'))

 model.compile(loss='categorical_crossentropy',
              optimizer='Adam',
              metrics=['accuracy'])

 LOGDIR = "logs"
 my_tensorboard = TensorBoard(log_dir = LOGDIR,
       histogram_freq=0,
       write_graph=True,
       write_images=True)
 my_batch_size = 128
 my_num_classes = 10
 my_epochs = 5

 history = model.fit(train_images, train_labels,
       batch_size=my_batch_size,
       callbacks=[my_tensorboard],
       epochs=my_epochs,
       use_multiprocessing=False,
       verbose=1,
       validation_data=(test_images, test_labels))

 score = model.evaluate(test_images, test_labels)

 modeldir = 'models'
 model.save(modeldir, save_format = 'tf')

对于Java ,我正在尝试调整此处发布的App.java代码。

我正在努力替换这个片段:

 Tensor result = s.runner()
      .feed("input_tensor", inputTensor)
      .feed("dropout/keep_prob", keep_prob)
      .fetch("output_tensor")
      .run().get(0);

虽然在这段代码中,一个特定的输入张量用于传递数据,但在我的模型中,只有层,没有单独的命名张量。 因此,以下不起作用:

 Tensor<?> result = s.runner()
      .feed("Input_Layer/kernel", inputTensor)
      .fetch("Output_Layer/kernel")
      .run().get(0);

如何在 Java 中将数据传递给模型并从中获取输出?

我终于设法找到了解决方案。 为了获取图中的所有张量名称,我使用了以下代码:

        for (Iterator it = smb.graph().operations(); it.hasNext();) {
            Operation op = (Operation) it.next();
            System.out.println("Operation name: " + op.name());
        }

从这里,我发现以下工作:

        SavedModelBundle smb = SavedModelBundle.load("./model", "serve");
        Session s = smb.session();

        Tensor<Float> inputTensor = Tensor.<Float>create(imagesArray, Float.class);
        Tensor<Float> result = s.runner()
                .feed("serving_default_Input_Layer_input", inputTensor)
                .fetch("StatefulPartitionedCall")
                .run().get(0).expect(Float.class);

使用最新版本的TensorFlow Java ,您无需从模型签名或图中自行搜索输入/输出张量的名称。 您可以简单地调用以下内容:

try (SavedModelBundle model = SavedModelBundle.load("./model", "serve");
     Tensor<TFloat32> image = TFloat32.tensorOf(...); // There a many ways to pass you image bytes here
     Tensor<TFloat32> result = model.call(image).expect(TFloat32.DTYPE)) {
    System.out.println("Result is " + result.data().getFloat());
  }
}

TensorFlow Java 将自动负责将您的输入/输出张量映射到正确的节点。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM