
Same TensorFlow model inference in Tomcat gets different results than in a simple Java application

We are deploying a TensorFlow model (seq2seq question answering) in Tomcat 7 (Java 1.8). While debugging, we used a simple Java application (a public static void main() function) to test the model's inference results. The inference results in the simple Java application match the original Python version. But when we launch the whole package (WAR) in Tomcat, we get quite different results, even though the inference code, test input sentences, and model files are identical.

Can anyone give us some hints about this problem?

  1. The simple Java application (public static void main() function) gets the same results as the original Python TensorFlow inference; we treat these as the correct ones. (A minimal sketch of this standalone harness follows this list.)
  2. The model loaded in Tomcat gives different results. The output looks like a normal sentence, but as an answer to the question its meaning is quite poor.
  3. The model files (protobuf), Java code, and test input sentences are the same in both cases.
  4. The dropout keep probability is 1.0f for inference.
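
For reference, the standalone test looks roughly like the sketch below. The wrapper class name (BiSeq2SeqPredictor) and the model path are placeholders we introduce for illustration; reload and predict are the methods listed further down.

public class PredictorMain {
    public static void main(String[] args) {
        // BiSeq2SeqPredictor is a placeholder name for the class that
        // holds the reload()/predict() methods shown below.
        BiSeq2SeqPredictor predictor = new BiSeq2SeqPredictor();
        if (!predictor.reload("/path/to/saved_model")) {  // placeholder model path
            System.err.println("model load failed");
            return;
        }
        // Same test sentences as fed to the Tomcat deployment.
        String answer = predictor.predict("test question", "candidate answer");
        System.out.println("predict_tokens: " + answer);
    }
}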

Model loading function:

@Override
public boolean reload(String modelURL) {
    logger.info("tensorflow version:{}", TensorFlow.version());
    try {
        logger.info("start to download model path:{}", modelURL);
        //TODO: download model
        logger.info("start to load model path:{} tag:{}", MODEL_PATH, MODEL_TAG);
        // Load the SavedModel from disk and keep its session for all later inference calls.
        bundle = SavedModelBundle.load(MODEL_PATH, MODEL_TAG);
        session = bundle.session();
        logger.info("finish loading model!");

    } catch(Exception e) {
        logger.error("reload model exception:", e);
        return false;
    }

    return true;
}
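
For context (an assumption on our part, since the value of MODEL_TAG is not shown): SavedModels exported for serving are conventionally tagged "serve" (tf.saved_model.tag_constants.SERVING on the Python side), so the load call is typically equivalent to:

bundle = SavedModelBundle.load(MODEL_PATH, "serve");  // "serve" is the conventional serving tag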

Inference code:

@Override
public String predict(String query, String candidateAnswer) {
    if (StringUtils.isEmpty(query) || StringUtils.isEmpty(candidateAnswer)) {
        logger.info(String.format("query:%s candidate:%s can't be empty or null!", query, candidateAnswer));
        return null;
    }
    String queryPad = preprocess(query, SEQUENCE_MAX_LEN);
    String candidatePad = preprocess(candidateAnswer, SEQUENCE_MAX_LEN);

    try (Tensor queryTensor = Tensor.create(queryPad.getBytes());  // getBytes() with no argument uses the JVM default charset
        Tensor queryLenTensor = Tensor.create(SEQUENCE_MAX_LEN);
        Tensor candidateTensor = Tensor.create(candidatePad.getBytes());
        Tensor candidateLenTensor = Tensor.create(SEQUENCE_MAX_LEN))
    {
        List<Tensor> result = session.runner()
                .feed("source_tokens", queryTensor)
                .feed("source_len", queryLenTensor)
                .feed("source_candidate_tokens", candidateTensor)
                .feed("source_candidate_len", candidateLenTensor)
                .fetch("model/att_seq2seq/predicted_tokens_scalar")
                .run();

        Tensor predictedTensor = result.get(0);
        String predictedTokens = new String(predictedTensor.bytesValue(), "UTF-8");
        logger.info(String.format("biseq2seq model generate:\nquery:%s\ncandidate:%s\npredict_tokens:%s", query.trim(), candidateAnswer.trim(), predictedTokens));
        return predictedTokens;
    } catch (Exception e) {
        logger.error("exception:", e);
    }

    return null;
}

Yes, it turned out to be an encoding problem. When we launch the model in a simple Java application (public static void main()), the JVM default encoding is UTF-8, so getBytes() produces UTF-8 bytes. But when we launch the model in Tomcat, the default encoding is ISO-8859-1, so the same call hands the model differently encoded bytes and the non-ASCII input tokens are corrupted.
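
One quick way to confirm the difference (our sketch, not part of the original post) is to print the JVM default charset in both environments; the value is fixed at JVM startup from the file.encoding system property:

import java.nio.charset.Charset;

public class CharsetCheck {
    public static void main(String[] args) {
        // In the standalone application this printed UTF-8;
        // under our Tomcat instance it printed ISO-8859-1.
        System.out.println("default charset: " + Charset.defaultCharset());
        System.out.println("file.encoding:   " + System.getProperty("file.encoding"));
    }
}

Under Tomcat, the same two lines can be logged from the reload method or a ServletContextListener at webapp startup. The fix is to pass the charset explicitly instead of relying on the default: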

Tensor queryTensor = Tensor.create(queryPad.getBytes("UTF-8"));
Tensor candidateTensor = Tensor.create(candidatePad.getBytes("UTF-8"));
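
A small refinement (our suggestion, not part of the original fix): getBytes("UTF-8") throws the checked UnsupportedEncodingException, while the getBytes(Charset) overload does not, so StandardCharsets.UTF_8 is slightly cleaner:

import java.nio.charset.StandardCharsets;

// Same bytes as getBytes("UTF-8"), but no checked exception to handle.
Tensor queryTensor = Tensor.create(queryPad.getBytes(StandardCharsets.UTF_8));
Tensor candidateTensor = Tensor.create(candidatePad.getBytes(StandardCharsets.UTF_8));

For completeness, here is the corrected predict method in full: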

@Override
public String predict(String query, String candidateAnswer) {
    if (StringUtils.isEmpty(query) || StringUtils.isEmpty(candidateAnswer)) {
        logger.info(String.format("query:%s candidate:%s can't be empty or null!", query, candidateAnswer));
        return null;
    }
    String queryPad = preprocess(query, SEQUENCE_MAX_LEN);
    String candidatePad = preprocess(candidateAnswer, SEQUENCE_MAX_LEN);

    try (Tensor queryTensor = Tensor.create(queryPad.getBytes("UTF-8"));  // explicit charset, independent of the container default
        Tensor queryLenTensor = Tensor.create(SEQUENCE_MAX_LEN);
        Tensor candidateTensor = Tensor.create(candidatePad.getBytes("UTF-8"));
        Tensor candidateLenTensor = Tensor.create(SEQUENCE_MAX_LEN))
    {
        List<Tensor> result = session.runner()
                .feed("source_tokens", queryTensor)
                .feed("source_len", queryLenTensor)
                .feed("source_candidate_tokens", candidateTensor)
                .feed("source_candidate_len", candidateLenTensor)
                .fetch("model/att_seq2seq/predicted_tokens_scalar")
                .run();

        Tensor predictedTensor = result.get(0);
        String predictedTokens = new String(predictedTensor.bytesValue(), "UTF-8");
        logger.info(String.format("biseq2seq model generate:\nquery:%s\ncandidate:%s\npredict_tokens:%s", query.trim(), candidateAnswer.trim(), predictedTokens));
        return predictedTokens;
    } catch (Exception e) {
        logger.error("exception:", e);
    }

    return null;
}
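
As an environment-level alternative (our note, not from the original post), the JVM default charset can be pinned for the whole container by launching Tomcat with -Dfile.encoding=UTF-8, for example via CATALINA_OPTS in bin/setenv.sh. Passing the charset explicitly in code, as above, remains the more robust fix, because it does not depend on how the JVM happens to be launched.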
