
Why do I get worse results in Java than in Python when using the same TensorFlow models?

Introduction:

For educational purposes I developed a Java class that enables students to load TensorFlow models in the TensorFlow SavedModel format and use them for classification in Java. For example, they can create a model online with Google's Teachable Machine, download it and use it right in Java. This also works with many image classification models on tfhub.dev. For this I tried to use the new but not well documented Java API, not the deprecated old libtensorflow API (if I understood everything correctly). As I use BlueJ, everything is based on pure Java code that links the required libraries directly in BlueJ's preferences after downloading them. The documentation in the Java code shows where to download the libraries.

Note: I know that "the normal way today" is to use Gradle or Maven or something similar, but the students do not work with these tools. Another note: in the following I only use a few code excerpts in order to keep this a minimal example.
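For orientation, here is a minimal sketch (the model folder path is just a placeholder) that loads a SavedModel and prints its signatures; this is essentially what the constructor and getModelInformation() in the full code below do to find the input and output layer names.

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Signature;

public class SignatureInspector {
    public static void main(String[] args) {
        // placeholder path: an unpacked SavedModel folder, e.g. downloaded from Teachable Machine
        String modelFolder = "/path/to/saved_model";
        try (SavedModelBundle model = SavedModelBundle.load(modelFolder, "serve")) {
            // each Signature prints its key and the names of its input and output tensors
            for (Signature signature : model.signatures()) {
                System.out.println(signature);
            }
        }
    }
}

The printed names are the ones that the code below extracts via getInputs().keySet() and getOutputs().keySet().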

Problem:

The results of all my loaded models in Java are OK, but not as good as in Python or in the online demonstrations linked on the TensorFlow website (mainly Jupyter notebooks). So there seems to be one step wrong in my code.

As a representative test I will now compare the results of the MoveNet model in Python and in Java. The MoveNet "Thunder" model detects 17 keypoints of a body in a 256x256 pixel image. I use exactly the same image (the same file, without touching or resizing it) in both setups. (I uploaded it to my webspace while updating this text; this makes no difference to the results.)

Python: The MoveNet model comes with a nice online Python demo in a Jupyter notebook:

https://www.tensorflow.org/hub/tutorials/movenet

The code can be found here (note: I linked to the same image as in my Java project by uploading it to my webspace), and the classification result for the image looks like this:

[Image: classification result from the Python demo]

Java: My Java-based approach ends up with an image like this:

[Image: classification result from my Java code]

I think that is not bad, but it isn't perfect. With other models, e.g. Google's imagenet_mobilenet model, I get similar results that are OK, but I suspect they are always a bit better when running the online demos in Jupyter notebooks. I have no more evidence than that, only a feeling. In some cases the same image from the online demo is recognized as a different class, but not always. I might provide more data on that later.

Assumption and work done so far:

There might be an error in the data structures, or in how I operate on them, in my Java code. I have searched the web for some weeks now, but I am unsure whether my code is really correct, mainly because there are so few examples out there. For example, I tried to change the RGB order and the way the channel values are calculated in the method that converts an image into an NdArray, but I saw no significant changes. Maybe the error is somewhere else, or maybe this is simply how it is. If my code is correct as it stands, that is also fine with me, but I am still wondering why there are differences. Thanks for any answers!
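One sanity check of that kind, just as a sketch with an arbitrary pixel value (not my exact test code): unpacking the channels by bit shifting gives the same values as going through java.awt.Color, so the channel extraction itself does not seem to be the issue.

import java.awt.Color;

public class ChannelExtractionCheck {
    public static void main(String[] args) {
        int rgb = 0xFF336699;                  // arbitrary packed ARGB pixel value
        Color c = new Color(rgb);
        int red   = (rgb >> 16) & 0xFF;        // 0x33
        int green = (rgb >> 8)  & 0xFF;        // 0x66
        int blue  =  rgb        & 0xFF;        // 0x99
        // prints true: both ways of reading the channels agree
        System.out.println(c.getRed() == red && c.getGreen() == green && c.getBlue() == blue);
    }
}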

Code:

Here is a fully working example with two classes (I know the frame with the panel drawing is ugly; I coded it quickly just for this example):

/**
 * 1. TensorFlow Core API Library: org.tensorflow -> tensorflow-core-api
 *      https://mvnrepository.com/artifact/org.tensorflow/tensorflow-core-api
 *          -> tensorflow-core-api-0.4.0.jar
 *      
 * 2.   additionally click "View All" and open:
 *      https://repo1.maven.org/maven2/org/tensorflow/tensorflow-core-api/0.4.0/
 *      Download the correct native library for your OS
 *          -> tensorflow-core-api-0.4.0-macosx-x86_64.jar
 *          -> tensorflow-core-api-0.4.0-windows-x86_64.jar
 *          -> tensorflow-core-api-0.4.0-linux-x86_64.jar 
 *      
 * 3. TensorFlow Framework Library:  org.tensorflow -> tensorflow-framework
 *      https://mvnrepository.com/artifact/org.tensorflow/tensorflow-framework/0.4.0
 *          -> tensorflow-framework-0.4.0.jar      
 *          
 * 4. Protocol Buffers [Core]: com.google.protobuf -> protobuf-java
 *      https://mvnrepository.com/artifact/com.google.protobuf/protobuf-java
 *          -> protobuf-java-4.0.0-rc-2.jar
 * 
 * 5. JavaCPP: org.bytedeco -> javacpp
 *      https://mvnrepository.com/artifact/org.bytedeco/javacpp
 *          -> javacpp-1.5.7.jar
 * 
 * 6. TensorFlow NdArray Library:  org.tensorflow -> ndarray
 *      https://mvnrepository.com/artifact/org.tensorflow/ndarray
 *          -> ndarray-0.3.3.jar
 */
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.IntNdArray;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TInt32;
import java.util.HashMap;
import java.util.Map;
import java.awt.image.BufferedImage;
import javax.imageio.ImageIO;
import java.awt.Color;
import java.io.File;
import javax.swing.JFrame;
import javax.swing.JButton;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.BorderLayout;

public class MoveNetDemo {

    private SavedModelBundle model;
    private String inputLayerName;
    private String outputLayerName;
    private String keyName;
    private BufferedImage image;
    private float[][] output;    
    private int width;
    private int height;

    public MoveNetDemo(String pFoldername, int pImageWidth, int pImageHeight) {
        width = pImageWidth;
        height = pImageHeight;

        model = SavedModelBundle.load(pFoldername, "serve");
        // Read input and output layer names from file
        inputLayerName = model.signatures().get(0).getInputs().keySet().toString();
        outputLayerName = model.signatures().get(0).getOutputs().keySet().toString();
        inputLayerName = inputLayerName.substring(1, inputLayerName.length()-1);
        outputLayerName = outputLayerName.substring(1, outputLayerName.length()-1);
        keyName = model.signatures().get(0).key();        
    }

    // not necessary here
    public String getModelInformation() { 
        String infos = "";
        for (int i=0; i<model.signatures().size(); i++) {
            infos += model.signatures().get(i).toString();
        }         
        return infos;
    }  

    public void setData(String pFilename) {
        image = null;
        try {
            image = ImageIO.read(new File(pFilename));            
        } 
        catch (Exception e) {          
        }
    }

    public BufferedImage getData() {
        return image;
    }

    private IntNdArray fillIntNdArray(IntNdArray pMatrix, BufferedImage pImage) {        
        try {
            int w = pImage.getWidth();
            int h = pImage.getHeight();                

            for (int i = 0; i < h; i++) {
                for (int j = 0; j < w; j++) {                 
                    Color mycolor = new Color(pImage.getRGB(j, i));
                    int red = mycolor.getRed();
                    int green = mycolor.getGreen();
                    int blue = mycolor.getBlue();
                    pMatrix.setInt(red, 0, j, i, 0);
                    pMatrix.setInt(green, 0, j, i, 1);
                    pMatrix.setInt(blue, 0, j, i, 2);                                       
                }
            }
        }
        catch (Exception e) {            
        }
        return pMatrix;        
    }

    public void run() {
        Map<String, Tensor> feed_dict = null;
        IntNdArray input_matrix = NdArrays.ofInts(Shape.of(1, width, height, 3));
        input_matrix = fillIntNdArray(input_matrix, image);            
        Tensor input_tensor = TInt32.tensorOf(input_matrix);
        feed_dict = new HashMap<>();
        feed_dict.put(inputLayerName, input_tensor); 
        Map<String, Tensor> res = model.function(keyName).call(feed_dict);                
        Tensor output_tensor = res.get(outputLayerName); 

        output = new float[17][3];
        for (int i= 0; i<17; i++) {
            output[i][0] = output_tensor.asRawTensor().data().asFloats().getFloat(i*3)*256;                
            output[i][1] = output_tensor.asRawTensor().data().asFloats().getFloat(i*3+1)*256;                
            output[i][2] = output_tensor.asRawTensor().data().asFloats().getFloat(i*3+2);
        }
    }

    public float[][] getOutputArray() {
        return output;
    }

    public static void main(String[] args) {
        MoveNetDemo im = new MoveNetDemo("/Users/myname/Downloads/Code/TF_Test_04_NEW/movenet_singlepose_thunder_4", 256, 256);        
        im.setData("/Users/myname/Downloads/Code/TF_Test_04_NEW/test.jpeg");

        JFrame jf = new JFrame("TEST");
        jf.setSize(300, 300);
        jf.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
        ImagePanel ip = new ImagePanel(im.getData());
        jf.add(ip, BorderLayout.CENTER);

        JButton st = new JButton("RUN");
        st.addActionListener(new ActionListener() { 
                public void actionPerformed(ActionEvent e) {
                    im.run();                            
                    ip.update(im.getOutputArray());
                    
                } 
            });        
        jf.add(st, BorderLayout.NORTH);

        jf.setVisible(true);
    }
}

and the ImagePanel class:

import javax.swing.JPanel;
import java.awt.image.BufferedImage;
import java.awt.Graphics;
import java.awt.Color;

public class ImagePanel extends JPanel {

    private BufferedImage image;
    private float[][] points;

    public ImagePanel(BufferedImage pImage) {        
        image = pImage;        
    }

    public void update(float[][] pPoints) {
        points = pPoints;
        repaint();
    }

    @Override
    protected void paintComponent(Graphics g) {                
        super.paintComponent(g);        
        g.drawImage(image, 0,0,null);
        g.setColor(Color.GREEN);
        if (points != null) {
            for (int j=0; j<17; j++) {                            
                g.fillOval((int)points[j][0], (int)points[j][1], 5, 5);
            } 
        }
    }
}

I found the answer: I mixed up height and width twice. No idea why this behaved so strangely (nearly correct but not perfect), but it works now.

In the Jupyter notebook it says:

input_image: A [1, height, width, 3]

so I changed the method fillIntNdArray to:

private IntNdArray fillIntNdArray(IntNdArray pMatrix, BufferedImage pImage) {        
        try {
            int w = pImage.getWidth();
            int h = pImage.getHeight();                

            for (int i = 0; i < h; i++) {
                for (int j = 0; j < w; j++) {                 
                    Color mycolor = new Color(pImage.getRGB(j, i));
                    int red = mycolor.getRed();
                    int green = mycolor.getGreen();
                    int blue = mycolor.getBlue();
                    pMatrix.setInt(red, 0, i, j, 0); // switched j and i 
                    pMatrix.setInt(green, 0, i, j, 1); // switched j and i 
                    pMatrix.setInt(blue, 0, i, j, 2); // switched j and i                                    
                }
            }
        }
        catch (Exception e) {            
        }
        return pMatrix;        
    }

and accordingly in the run() method:

IntNdArray input_matrix = NdArrays.ofInts(Shape.of(1, height, width, 3));

In the Jupyter notebook you can expand the helper functions for visualization and see that the y coordinate is taken first and then the x coordinate: height first, then width. Changing this in the ImagePanel class as well solves the problem, and the classification is as expected and of the same quality as in the online demonstration!

if (points != null) {
    for (int j=0; j<17; j++) {                            
        // switched 0 and 1
        g.fillOval((int)points[j][1], (int)points[j][0], 5, 5);
    } 
}

Here it is: [Image: corrected classification result in Java]
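For anyone who wants to draw the keypoints on an image that is not exactly 256x256, the hard-coded factor 256 in run() can be generalised. The following is only a sketch with a helper name I made up (toPixelCoordinates); it assumes the output is read as a FloatNdArray of shape [1, 1, 17, 3] with (y, x, score) per keypoint, which matches the y-first, x-second ordering described above.

import org.tensorflow.ndarray.FloatNdArray;

public class KeypointScaling {

    // keypoints: the model output read as [1, 1, 17, 3] with (y, x, score) per keypoint,
    // e.g. obtained by casting the output Tensor to TFloat32, which implements FloatNdArray.
    // Returns [17][3] = {x in pixels, y in pixels, score}.
    public static float[][] toPixelCoordinates(FloatNdArray keypoints, int imageWidth, int imageHeight) {
        float[][] result = new float[17][3];
        for (int k = 0; k < 17; k++) {
            float y = keypoints.getFloat(0, 0, k, 0);      // normalised height coordinate
            float x = keypoints.getFloat(0, 0, k, 1);      // normalised width coordinate
            result[k][0] = x * imageWidth;
            result[k][1] = y * imageHeight;
            result[k][2] = keypoints.getFloat(0, 0, k, 2); // confidence score
        }
        return result;
    }
}

With that, points[j][0] and points[j][1] in ImagePanel can be used directly as x and y, whatever size the displayed image has.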
