

How can I feed a sparse placeholder in a TensorFlow model from Java

I'm trying to calculate the best match for a given address with the kNN algorithm in TensorFlow, which works pretty well. But when I try to export the model and use it in our Java environment, I get stuck on how to feed the sparse placeholders from Java.

Here is a pretty much stripped-down version of the Python part, which returns the smallest distance between the test name and the best reference name. So far this works as expected. When I export the model and import it in my Java program, it always returns the same value (the distance for the placeholder's default). I assume that the Python function sparse_from_word_vec(word_vec) isn't in the model, which would totally make sense to me, but then how should I make this sparse tensor? My input is a single string, and I need to create a fitting sparse tensor (value) to calculate the distance. I also searched for a way to generate the sparse tensor on the Java side, but without success.
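For reference, a sparse tensor value is just three plain lists: the indices of the non-empty elements, their values, and a dense shape that must contain every index. The sketch below builds those lists for a batch of names in plain Python (no TensorFlow), in the same `[word, 0, char]` layout as `sparse_from_word_vec`. `sparse_components` is a hypothetical helper for illustration only. Note that the last dimension of the dense shape has to be at least the length of the longest word, not 1.

```python
def sparse_components(words):
    """Build (indices, values, dense_shape) for a batch of words, one
    character per element, in the [word, 0, char] layout used by
    sparse_from_word_vec."""
    indices = [[i, 0, j] for i, w in enumerate(words) for j in range(len(w))]
    values = [c for w in words for c in w]
    # every index must lie inside dense_shape, so the last dimension is
    # the length of the longest word (not 1)
    dense_shape = [len(words), 1, max((len(w) for w in words), default=0)]
    return indices, values, dense_shape

indices, values, dense_shape = sparse_components(["max", "joseph"])
print(dense_shape)   # [2, 1, 6]
print(indices[:4])   # [[0, 0, 0], [0, 0, 1], [0, 0, 2], [1, 0, 0]]
print(values[:4])    # ['m', 'a', 'x', 'j']
```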

import tensorflow as tf
import pandas as pd

d = {'NAME': ['max mustermann', 
              'erika musterfrau', 
              'joseph haydn', 
              'johann sebastian bach', 
              'wolfgang amadeus mozart']}

df = pd.DataFrame(data=d)  

input_name = tf.placeholder_with_default('max musterman',(), name='input_name')
output_dist = tf.placeholder(tf.float32, (), name='output_dist')

test_name = tf.sparse_placeholder(dtype=tf.string)
ref_names = tf.sparse_placeholder(dtype=tf.string)

output_dist = tf.edit_distance(test_name, ref_names, normalize=True)

def sparse_from_word_vec(word_vec):
    num_words = len(word_vec)
    indices = [[xi, 0, yi] for xi,x in enumerate(word_vec) for yi,y in enumerate(x)]
    chars = list(''.join(word_vec))
    return(tf.SparseTensorValue(indices, chars, [num_words, 1, max(len(w) for w in word_vec)]))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    t_data_names=tf.constant(df['NAME'])
    reference_names = [el.decode('UTF-8') for el in (t_data_names.eval())]

    sparse_ref_names = sparse_from_word_vec(reference_names)
    sparse_test_name = sparse_from_word_vec([str(input_name.eval().decode('utf-8'))]*5)

    feeddict={test_name: sparse_test_name,
              ref_names: sparse_ref_names, 
              }    

    output_dist = sess.run(output_dist, feed_dict=feeddict)
    output_dist = tf.reduce_min(output_dist, 0)
    print(output_dist.eval())

    tf.saved_model.simple_save(sess,
                               "model-simple",
                               inputs={"input_name": input_name},
                               outputs={"output_dist": output_dist})

And here is my Java method:

public void run(ApplicationArguments args) throws Exception {
  log.info("Loading model...");

  SavedModelBundle savedModelBundle = SavedModelBundle.load("/model", "serve");

  byte[] test_name = "Max Mustermann".toLowerCase().getBytes("UTF-8");


  List<Tensor<?>> output = savedModelBundle.session().runner()
      .feed("input_name", Tensor.<String>create(test_name))
      .fetch("output_dist")
      .run();

  System.out.println("Nearest distance: " + output.get(0).floatValue());

}

I was able to get your example working. I have a couple of comments on your Python code before diving in.

You use the variable output_dist for three different value types throughout the code. I'm not a Python expert, but I think it's bad practice. It also hides a bug: the final assignment wraps the NumPy result of sess.run in tf.reduce_min, so the exported output_dist depends only on a constant and returns the same value whatever you feed. You also never actually use the input_name placeholder, except for exporting it as an input. Lastly, tf.saved_model.simple_save is deprecated, and you should use tf.saved_model.Builder instead.

Now for the solution.

Looking at the libtensorflow jar file with the command jar tvf libtensorflow-xxx.jar (thanks to this post), you can see that there are no useful bindings for creating a sparse tensor (maybe make a feature request?). So we have to change the input to a dense tensor, then add operations to the graph that convert it to a sparse one. In your original code the sparse conversion happened on the Python side, which means the graph loaded in Java doesn't contain any ops for it.
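The dense-to-sparse step can be pictured without TensorFlow: pad every row to the same length with a token, then keep only the positions that are not that token. The following is a plain-Python sketch of that idea. It assumes `tf.contrib.layers.dense_to_sparse` behaves this way (drop every element equal to `eos_token`, keep the padded shape as the dense shape). `pad_chars` and `dense_to_sparse_sketch` are hypothetical names, not TensorFlow functions.

```python
def pad_chars(chars, max_len, pad_token="/"):
    # right-pad a list of characters so every row has the same length
    return chars + [pad_token] * (max_len - len(chars))

def dense_to_sparse_sketch(dense, eos_token="/"):
    """Keep every element that differs from eos_token (row-major order);
    the dense shape stays the padded shape of the input."""
    indices, values = [], []
    for i, row in enumerate(dense):
        for j, value in enumerate(row):
            if value != eos_token:
                indices.append([i, j])
                values.append(value)
    # the dense input is rectangular, so its shape is rows x row length
    dense_shape = [len(dense), len(dense[0]) if dense else 0]
    return indices, values, dense_shape

names = ["max", "joseph"]
max_len = max(len(n) for n in names)
dense = [pad_chars(list(n), max_len) for n in names]
print(dense[0])        # ['m', 'a', 'x', '/', '/', '/']
indices, values, shape = dense_to_sparse_sketch(dense)
print(shape)           # [2, 6]
print(indices[3])      # [1, 0]
```

One consequence of this design: every occurrence of the padding token is dropped, not only the trailing ones. A name that really contains "/" would lose that character, so pick a token that cannot appear in your data.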

Here is the new Python code:

import tensorflow as tf
import pandas as pd

def model():
    #use dense tensors then convert to sparse for edit_distance
    test_name = tf.placeholder(shape=(None, None), dtype=tf.string, name="test_name")
    ref_names = tf.placeholder(shape=(None, None), dtype=tf.string, name="ref_names")

    #Java does not play well with the empty character, so pad with "/" instead
    test_name_sparse = tf.contrib.layers.dense_to_sparse(test_name, "/")
    ref_names_sparse = tf.contrib.layers.dense_to_sparse(ref_names, "/")

    output_dist = tf.edit_distance(test_name_sparse, ref_names_sparse, normalize=True)

    #output the index to the closest ref name
    min_idx = tf.argmin(output_dist)

    return test_name, ref_names, min_idx

#Python code to be replicated in Java
def pad_string(s, max_len):
    return s + ["/"] * (max_len - len(s))

d = {'NAME': ['joseph haydn', 
              'max mustermann', 
              'erika musterfrau', 
              'johann sebastian bach', 
              'wolfgang amadeus mozart']}

df = pd.DataFrame(data=d)  
input_name = 'max musterman'

#pad dense tensor input
max_len = max([len(n) for n in df['NAME']])

#every row of the test input is the same string, so no padding is needed
test_input = [list(input_name)]*len(df['NAME'])
#pad the reference names to the longest name with "/"
ref_input = [pad_string(list(n), max_len) for n in df['NAME']]


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    test_name, ref_names, min_idx = model()

    #run a test to make sure the model works
    feeddict = {test_name: test_input,
                ref_names: ref_input}
    out = sess.run(min_idx, feed_dict=feeddict)
    print("test output:", out)

    #save the model with the new Builder API
    signature_def_map = {
        "predict": tf.saved_model.signature_def_utils.predict_signature_def(
            inputs={"test_name": test_name, "ref_names": ref_names},
            outputs={"min_idx": min_idx})
    }

    builder = tf.saved_model.Builder("model")
    builder.add_meta_graph_and_variables(sess, ["serve"], signature_def_map=signature_def_map)
    builder.save()
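To check what the exported graph should return (from Python or from Java), it helps to have an independent reference for the computation. The sketch below is a plain dynamic-programming Levenshtein distance, not TensorFlow's implementation. It assumes that `normalize=True` divides by the length of the truth (`ref_names`). `levenshtein` and `nearest` are hypothetical helper names.

```python
def levenshtein(a, b):
    """Classic dynamic-programming edit distance (insert/delete/substitute, cost 1 each)."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,               # delete ca
                           cur[j - 1] + 1,            # insert cb
                           prev[j - 1] + (ca != cb))) # substitute (free if equal)
        prev = cur
    return prev[-1]

def nearest(test_name, ref_names):
    """Return (index of closest ref name, list of normalized distances)."""
    # normalize by the length of the reference ("truth") string
    dists = [levenshtein(test_name, ref) / len(ref) for ref in ref_names]
    # min() returns the first index among equal distances
    best = min(range(len(dists)), key=dists.__getitem__)
    return best, dists

names = ['joseph haydn', 'max mustermann', 'erika musterfrau',
         'johann sebastian bach', 'wolfgang amadeus mozart']
best, dists = nearest('max musterman', names)
print(best, names[best])   # 1 max mustermann
```

With this data the nearest name is index 1, 'max mustermann' (one insertion, so 1/14). The comparison is case-sensitive. The Java code below feeds "Max Mustermann", which costs two substitutions against 'max mustermann', although index 1 is still the closest.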

And here is the Java code to load and run it. There is probably a lot of room for improvement here (Java isn't my main language), but it gives you the idea.
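The `makeTensor` helpers below produce a rectangular `byte[][][]` in which every innermost element is the UTF-8 encoding of a single character, padded with "/". The plain-Python sketch mirrors that layout so it can be compared with the `pad_string` input used in the Python script. `make_tensor` is a hypothetical name. One difference: Python iterates over code points while Java's `charAt` iterates over UTF-16 code units, so characters outside the Basic Multilingual Plane would be split differently.

```python
def make_tensor(strings, padding, pad_token="/"):
    """Return one row per string; each element is the UTF-8 bytes of one
    character, and rows are padded with pad_token up to `padding`."""
    rows = []
    for s in strings:
        if len(s) > padding:
            # the Java helper would fail here with an ArrayIndexOutOfBoundsException
            raise ValueError(f"{s!r} is longer than padding={padding}")
        row = [c.encode("utf-8") for c in s]
        row += [pad_token.encode("utf-8")] * (padding - len(s))
        rows.append(row)
    return rows

refs = make_tensor(["joseph haydn", "max mustermann"], 14)
print(len(refs), len(refs[0]))  # 2 14
print(refs[0][-2:])             # [b'/', b'/']
```

The padding passed to each call must be at least the length of the longest string in that list, which is why the Java code computes pad2 as the maximum name length.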

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

import java.util.Arrays;
import java.util.List;

public class Test {
    public static byte[][] makeTensor(String s, int padding) throws Exception
    {
        int len = s.length();
        int extra = padding - len;

        byte[][] ret = new byte[len + extra][];
        for (int i = 0; i < len; i++) {
            String cur = "" + s.charAt(i);
            byte[] cur_b = cur.getBytes("UTF-8");
            ret[i] = cur_b;
        }

        for (int i = 0; i < extra; i++) {
            byte[] cur = "/".getBytes("UTF-8");
            ret[len + i] = cur;
        }

        return ret;
    }
    public static byte[][][] makeTensor(List<String> l, int padding) throws Exception
    {
        byte[][][] ret = new byte[l.size()][][];
        for (int i = 0; i < l.size(); i++) {
            ret[i] = makeTensor(l.get(i), padding);
        }

        return ret;
    }
    public static void main(String[] args) throws Exception {
        System.out.println("Loading model...");

        SavedModelBundle savedModelBundle = SavedModelBundle.load("model", "serve");


        List<String> str_test_name = Arrays.asList("Max Mustermann",
            "Max Mustermann",
            "Max Mustermann",
            "Max Mustermann",
            "Max Mustermann");
        List<String> names = Arrays.asList("joseph haydn",
            "max mustermann",
            "erika musterfrau",
            "johann sebastian bach",
            "wolfgang amadeus mozart");

        //get the max length for each array
        int pad1 = str_test_name.get(0).length();
        int pad2 = 0;
        for (String var : names) {
            if(var.length() > pad2)
                pad2 = var.length();
        }


        byte[][][] test_name = makeTensor(str_test_name, pad1);
        byte[][][] ref_names = makeTensor(names, pad2);

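        //use try-with-resources so close() is called on every tensor
        try (Tensor<String> t_test_name = Tensor.create(test_name, String.class);
             Tensor<String> t_ref_names = Tensor.create(ref_names, String.class))
        {
            List<Tensor<?>> output = savedModelBundle.session().runner()
                .feed("test_name", t_test_name)
                .feed("ref_names", t_ref_names)
                .fetch("ArgMin")
                .run();

            //ArgMin returns the index of the closest reference name, not a distance
            try (Tensor<?> t_min_idx = output.get(0))
            {
                long idx = t_min_idx.longValue();
                System.out.println("Nearest reference name: " + names.get((int) idx) + " (index " + idx + ")");
            }
        }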
    }
}
