gRPC-only Tensorflow Serving client in C++

There is a fair amount of information out there on creating a gRPC-only client in Python (and even a few other languages), and I was able to get a working Python client that uses only gRPC for our implementation.

What I can't seem to find is a case where someone has successfully written the client in C++.

The constraints of the task are as follows:

  1. The build system cannot be bazel, because the final application already has its own build system.
  2. The client cannot include Tensorflow (which requires bazel to build against in C++).
  3. The application should use gRPC and not HTTP calls for speed.
  4. The application ideally won't call Python or otherwise execute shell commands.

Given the above constraints, and assuming that I extracted and generated the gRPC stubs, is this even possible? If so, can an example be provided?

Turns out, this isn't anything new if you have already done it in Python. Assuming the model has been named "predict" and the input to the model is called "inputs," the following is the Python code:

import logging
import grpc
from grpc import RpcError

from types_pb2 import DT_FLOAT
from tensor_pb2 import TensorProto
from tensor_shape_pb2 import TensorShapeProto
from predict_pb2 import PredictRequest
from prediction_service_pb2_grpc import PredictionServiceStub


class ModelClient:
    """Client Facade to work with a Tensorflow Serving gRPC API"""
    host = None
    port = None
    chan = None
    stub = None

    logger = logging.getLogger(__name__)

    def __init__(self, name, dims, dtype=DT_FLOAT, version=1):
        self.model = name
        self.dims = [TensorShapeProto.Dim(size=dim) for dim in dims]
        self.dtype = dtype
        self.version = version

    @property
    def hostport(self):
        """A host:port string representation"""
        return f"{self.host}:{self.port}"

    def connect(self, host='localhost', port=8500):
        """Connect to the gRPC server and initialize prediction stub"""
        self.host = host
        self.port = int(port)

        self.logger.info(f"Connecting to {self.hostport}...")
        self.chan = grpc.insecure_channel(self.hostport)

        self.logger.info("Initializing prediction gRPC stub.")
        self.stub = PredictionServiceStub(self.chan)

    def tensor_proto_from_measurement(self, measurement):
        """Pass in a measurement and return a tensor_proto protobuf object"""
        self.logger.info("Assembling measurement tensor.")
        return TensorProto(
            dtype=self.dtype,
            tensor_shape=TensorShapeProto(dim=self.dims),
            string_val=[bytes(measurement)]
        )

    def predict(self, measurement, timeout=10):
        """Execute prediction against TF Serving service"""
        if self.host is None or self.port is None \
                or self.chan is None or self.stub is None:
            self.connect()

        self.logger.info("Creating request.")
        request = PredictRequest()
        request.model_spec.name = self.model

        if self.version > 0:
            request.model_spec.version.value = self.version

        request.inputs['inputs'].CopyFrom(
            self.tensor_proto_from_measurement(measurement))

        self.logger.info("Attempting to predict against TF Serving API.")
        try:
            return self.stub.Predict(request, timeout=timeout)
        except RpcError as err:
            self.logger.error(err)
            self.logger.error('Predict failed.')
            return None

The following is a working (rough) C++ translation:

#include <iostream>
#include <memory>
#include <string>

#include <grpcpp/grpcpp.h>

#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "google/protobuf/map.h"

#include "types.grpc.pb.h"
#include "tensor.grpc.pb.h"
#include "tensor_shape.grpc.pb.h"
#include "predict.grpc.pb.h"
#include "prediction_service.grpc.pb.h"

using grpc::Channel;
using grpc::ClientContext;
using grpc::Status;

using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
using tensorflow::serving::PredictRequest;
using tensorflow::serving::PredictResponse;
using tensorflow::serving::PredictionService;

typedef google::protobuf::Map<std::string, tensorflow::TensorProto> OutMap;

class ServingClient {
 public:
  ServingClient(std::shared_ptr<Channel> channel)
      : stub_(PredictionService::NewStub(channel)) {}

  // Assembles the client's payload, sends it and presents the response back
  // from the server.
  std::string callPredict(const std::string& model_name,
                          const float* measurement, size_t count) {

    // Data we are sending to the server.
    PredictRequest request;
    request.mutable_model_spec()->set_name(model_name);

    // Container for the data we expect from the server.
    PredictResponse response;

    // Context for the client. It could be used to convey extra information to
    // the server and/or tweak certain RPC behaviors.
    ClientContext context;

    google::protobuf::Map<std::string, tensorflow::TensorProto>& inputs =
      *request.mutable_inputs();

    tensorflow::TensorProto proto;
    proto.set_dtype(tensorflow::DataType::DT_FLOAT);

    // Copy every measurement value into the tensor; the shape below declares
    // 5 * 8 * 105 elements, so sending a single float would be rejected.
    for (size_t i = 0; i < count; ++i) {
      proto.add_float_val(measurement[i]);
    }

    proto.mutable_tensor_shape()->add_dim()->set_size(5);
    proto.mutable_tensor_shape()->add_dim()->set_size(8);
    proto.mutable_tensor_shape()->add_dim()->set_size(105);

    inputs["inputs"] = proto;

    // The actual RPC.
    Status status = stub_->Predict(&context, request, &response);

    // Act upon its status.
    if (status.ok()) {
      std::cout << "call predict ok" << std::endl;
      std::cout << "outputs size is " << response.outputs_size() << std::endl;

      OutMap& map_outputs = *response.mutable_outputs();

      for (OutMap::iterator iter = map_outputs.begin();
           iter != map_outputs.end(); ++iter) {
        tensorflow::TensorProto& result_tensor_proto = iter->second;
        const std::string& section = iter->first;
        std::cout << std::endl << section << ":" << std::endl;

        if ("classes" == section) {
          for (int titer = 0; titer < result_tensor_proto.int64_val_size(); ++titer) {
            std::cout << result_tensor_proto.int64_val(titer) << ", ";
          }
        } else if ("scores" == section) {
          for (int titer = 0; titer < result_tensor_proto.float_val_size(); ++titer) {
            std::cout << result_tensor_proto.float_val(titer) << ", ";
          }
        }
        std::cout << std::endl;
      }
      return "Done.";
    } else {
      std::cout << "gRPC call return code: " << status.error_code() << ": "
                << status.error_message() << std::endl;
      return "RPC failed";
    }
  }

 private:
  std::unique_ptr<PredictionService::Stub> stub_;
};

Note that the dimensions here have been specified within the code instead of passed in.
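
To mirror the Python client's dims argument instead of hard-coding the shape, the TensorShapeProto can be built from a runtime list. A minimal sketch under the same generated-proto assumptions as above; the helper name buildTensorShape is hypothetical:

#include <cstdint>
#include <vector>

#include "tensor_shape.pb.h"

// Hypothetical helper (not part of the original translation): build a
// TensorShapeProto from a runtime list of dimension sizes, mirroring the
// dims argument of the Python ModelClient above.
tensorflow::TensorShapeProto buildTensorShape(const std::vector<int64_t>& dims) {
  tensorflow::TensorShapeProto shape;
  for (int64_t dim : dims) {
    shape.add_dim()->set_size(dim);
  }
  return shape;
}

With that helper, the three add_dim() calls in callPredict collapse to a single assignment: *proto.mutable_tensor_shape() = buildTensorShape({5, 8, 105});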

Given the above class, execution can then be as follows:

int main(int argc, char** argv) {
  float measurement[5*8*105] = { ... data ... };

  ServingClient sclient(grpc::CreateChannel(
      "localhost:8500", grpc::InsecureChannelCredentials()));
  std::string model("predict");
  std::string reply = sclient.callPredict(model, measurement, 5 * 8 * 105);
  std::cout << "Predict received: " << reply << std::endl;

  return 0;
}
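
One thing the Python client has that the C++ translation does not is the RPC timeout (timeout=10 in predict). With grpc++ the equivalent is a deadline set on the ClientContext before the call. A minimal sketch; the helper name setPredictDeadline is hypothetical:

#include <chrono>

#include <grpcpp/grpcpp.h>

// Hypothetical helper: give the RPC a deadline, mirroring the Python client's
// timeout=10. A call that runs past the deadline fails and status.error_code()
// returns grpc::StatusCode::DEADLINE_EXCEEDED.
void setPredictDeadline(grpc::ClientContext& context, int seconds) {
  context.set_deadline(std::chrono::system_clock::now() +
                       std::chrono::seconds(seconds));
}

Calling setPredictDeadline(context, 10) just before stub_->Predict(&context, request, &response) reproduces the Python behavior.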

The Makefile used was borrowed from the gRPC C++ examples, with the PROTOS_PATH variable set relative to the Makefile and the following build target (assuming the C++ application file is named predict.cc):

predict: types.pb.o types.grpc.pb.o tensor_shape.pb.o tensor_shape.grpc.pb.o resource_handle.pb.o resource_handle.grpc.pb.o model.pb.o model.grpc.pb.o tensor.pb.o tensor.grpc.pb.o predict.pb.o predict.grpc.pb.o prediction_service.pb.o prediction_service.grpc.pb.o predict.o
    $(CXX) $^ $(LDFLAGS) -o $@
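
The .pb.o and .grpc.pb.o objects above are produced by the pattern rules in that same example Makefile, which invoke protoc with the C++ and gRPC plugins. For reference, the rules look roughly like this (the PROTOC, PROTOS_PATH, and GRPC_CPP_PLUGIN_PATH values vary by installation, so treat this as a sketch rather than a drop-in):

%.grpc.pb.cc: %.proto
    $(PROTOC) -I $(PROTOS_PATH) --grpc_out=. --plugin=protoc-gen-grpc=$(GRPC_CPP_PLUGIN_PATH) $<

%.pb.cc: %.proto
    $(PROTOC) -I $(PROTOS_PATH) --cpp_out=. $<

At the time of writing, the .proto files themselves can be extracted from tensorflow/core/framework in the tensorflow repository and tensorflow_serving/apis in the serving repository.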
