网上大量的示例代码都基于 MNIST，而我并不从事图像处理，所以实在不想用这个例子；
Wide & Deep 这类包含多种特征的模型才是我真正需要的；Iris 同样是从文本格式的表格数据训练模型，所以非常简单；
本文给出使用 Python 和 Java 访问 TensorFlow Serving 的代码。
Java 版本：使用 gRPC 访问 TensorFlow Serving 的代码
package io.github.qf6101.tensorflowserving; import com.google.protobuf.ByteString; import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.netty.NegotiationType; import io.grpc.netty.NettyChannelBuilder; import org.tensorflow.example.*; import org.tensorflow.framework.DataType; import org.tensorflow.framework.TensorProto; import org.tensorflow.framework.TensorShapeProto; import tensorflow.serving.Model; import tensorflow.serving.Predict; import tensorflow.serving.PredictionServiceGrpc; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; /** * 参考:https://www.jianshu.com/p/d82107165119 * 参考:https://github.com/grpc/grpc-java */ public class PssIrisGrpcClient { public static Example createExample() { Features.Builder featuresBuilder = Features.newBuilder(); Map<String, Float> dataMap = new HashMap<String, Float>(); dataMap.put("SepalLength", 5.1f); dataMap.put("SepalWidth", 3.3f); dataMap.put("PetalLength", 1.7f); dataMap.put("PetalWidth", 0.5f); Map<String, Feature> featuresMap = mapToFeatureMap(dataMap); featuresBuilder.putAllFeature(featuresMap); Features features = featuresBuilder.build(); Example.Builder exampleBuilder = Example.newBuilder(); exampleBuilder.setFeatures(features); return exampleBuilder.build(); } private static Map<String, Feature> mapToFeatureMap(Map<String, Float> dataMap) { Map<String, Feature> resultMap = new HashMap<String, Feature>(); for (String key : dataMap.keySet()) { // // data1 = {"SepalLength":5.1,"SepalWidth":3.3,"PetalLength":1.7,"PetalWidth":0.5} FloatList floatList = FloatList.newBuilder().addValue(dataMap.get(key)).build(); Feature feature = Feature.newBuilder().setFloatList(floatList).build(); resultMap.put(key, feature); } return resultMap; } public static void main(String[] args) { String host = "127.0.0.1"; int port = 8888; ManagedChannel channel = ManagedChannelBuilder.forAddress(host, port) // Channels are secure by default (via SSL/TLS). 
For the example we disable TLS to avoid // needing certificates. .usePlaintext() .build(); PredictionServiceGrpc.PredictionServiceBlockingStub blockingStub = PredictionServiceGrpc.newBlockingStub(channel); com.google.protobuf.Int64Value version = com.google.protobuf.Int64Value.newBuilder() .setValue(1) .build(); Model.ModelSpec modelSpec = Model.ModelSpec.newBuilder() .setName("iris") .setVersion(version) .setSignatureName("classification") .build(); List<ByteString> exampleList = new ArrayList<ByteString>(); exampleList.add(createExample().toByteString()); TensorShapeProto.Dim featureDim = TensorShapeProto.Dim.newBuilder().setSize(exampleList.size()).build(); TensorShapeProto shapeProto = TensorShapeProto.newBuilder().addDim(featureDim).build(); org.tensorflow.framework.TensorProto tensorProto = TensorProto.newBuilder().addAllStringVal(exampleList).setDtype(DataType.DT_STRING).setTensorShape(shapeProto).build(); Predict.PredictRequest request = Predict.PredictRequest.newBuilder() .setModelSpec(modelSpec) .putInputs("inputs", tensorProto) .build(); tensorflow.serving.Predict.PredictResponse response = blockingStub.predict(request); System.out.println(response); channel.shutdown(); } }
需要添加如下 Maven 依赖：
<!-- https://mvnrepository.com/artifact/org.tensorflow/tensorflow --> <dependency> <groupId>org.tensorflow</groupId> <artifactId>tensorflow</artifactId> <version>1.12.0</version> </dependency> <!-- https://mvnrepository.com/artifact/io.grpc/grpc-netty --> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-netty</artifactId> <version>1.20.0</version> </dependency> <!-- https://mvnrepository.com/artifact/io.grpc/grpc-protobuf --> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-protobuf</artifactId> <version>1.20.0</version> </dependency> <!-- https://mvnrepository.com/artifact/io.grpc/grpc-stub --> <dependency> <groupId>io.grpc</groupId> <artifactId>grpc-stub</artifactId> <version>1.20.0</version> </dependency>
输出结果:
outputs { key: "scores" value { dtype: DT_FLOAT tensor_shape { dim { size: 1 } dim { size: 3 } } float_val: 0.9997806 float_val: 2.1938368E-4 float_val: 1.382611E-9 } } outputs { key: "classes" value { dtype: DT_STRING tensor_shape { dim { size: 1 } dim { size: 3 } } string_val: "0" string_val: "1" string_val: "2" } }
# Python client: classify two Iris samples via TensorFlow Serving's Classify gRPC API.
import pandas as pd  # NOTE(review): not used below; kept so the script's imports stay unchanged
import grpc
import tensorflow as tf
from tensorflow_serving.apis import classification_pb2, prediction_service_pb2_grpc

# The original commented-out line targeted port 8500; this deployment listens on 8888.
HOST = '127.0.0.1'
PORT = 8888
# Per-call deadline in seconds.
TIMEOUT_SECONDS = 10.0


def _create_feature(v):
    """Wrap a single float value in a tf.train.Feature backed by a one-element FloatList."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[v]))


def _to_example(data):
    """Convert a {feature name: float} dict into a tf.train.Example."""
    features = {k: _create_feature(v) for k, v in data.items()}
    return tf.train.Example(features=tf.train.Features(feature=features))


data1 = {"SepalLength": 5.1, "SepalWidth": 3.3, "PetalLength": 1.7, "PetalWidth": 0.5}
data2 = {"SepalLength": 1.1, "SepalWidth": 1.3, "PetalLength": 1.7, "PetalWidth": 0.5}

# Convert the test samples into Example instances.
example1 = _to_example(data1)
example2 = _to_example(data2)
examples = [example1, example2]

# Prepare the RPC request, naming the target model.
request = classification_pb2.ClassificationRequest()
request.model_spec.name = 'iris'
request.input.example_list.examples.extend(examples)

# Open a plaintext channel, call Classify, and close the channel when done.
with grpc.insecure_channel('{}:{}'.format(HOST, PORT)) as channel:
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
    response = stub.Classify(request, timeout=TIMEOUT_SECONDS)

print(response)
Python 代码看起来简洁不少，但我们的线上服务都是用 Java 实现的，不便集成，因此 Python 版本只能用于一些离线的批量预测；
输出如下:
result { classifications { classes { label: "0" score: 0.9997805953025818 } classes { label: "1" score: 0.00021938368445262313 } classes { label: "2" score: 1.382611025668723e-09 } } classifications { classes { label: "0" score: 0.0736534595489502 } classes { label: "1" score: 0.8393719792366028 } classes { label: "2" score: 0.08697459846735 } } } model_spec { name: "iris" version { value: 1 } signature_name: "serving_default" }
我个人其实更喜欢 HTTP + JSON 接口，完全不用折腾 gRPC 这些麻烦的东西——尤其是 Java 版的 gRPC，遇到了很多问题，非常崩溃；
不过 gRPC 据称比 HTTP 性能好不少，所以线上只能用 gRPC。