With TF-Serving built from source in the previous post, we can now modify its code so that it can be embedded in Spring Cloud.
Add a health.proto under tensorflow_serving/apis:

    syntax = "proto3";

    option cc_enable_arenas = true;

    package tensorflow.serving;

    message HealthResponse {
      // health status, return UP
      string status = 1;
    }
Add build rules for this proto in apis/BUILD:

    serving_proto_library(
        name = "health_proto",
        srcs = ["health.proto"],
        cc_api_version = 2,
        deps = [
            ":model_proto",
            "//tensorflow_serving/util:status_proto",
        ],
    )

    serving_proto_library_py(
        name = "health_proto_py_pb2",
        srcs = ["health.proto"],
        proto_library = "health_proto",
        deps = [
            ":model_proto_py_pb2",
            "//tensorflow_serving/util:status_proto_py_pb2",
        ],
    )
In the Bazel build file tensorflow_serving/model_servers/BUILD, add the health_proto target defined above to the deps of the http_rest_api_handler target:

    cc_library(
        name = "http_rest_api_handler",
        srcs = ["http_rest_api_handler.cc"],
        hdrs = ["http_rest_api_handler.h"],
        visibility = ["//visibility:public"],
        deps = [
            ":get_model_status_impl",
            ":server_core",
            "//tensorflow_serving/apis:model_proto",
            "//tensorflow_serving/apis:predict_proto",
            "//tensorflow_serving/apis:health_proto",
            "//tensorflow_serving/core:servable_handle",
            "//tensorflow_serving/servables/tensorflow:classification_service",
            "//tensorflow_serving/servables/tensorflow:get_model_metadata_impl",
            "//tensorflow_serving/servables/tensorflow:predict_impl",
            "//tensorflow_serving/servables/tensorflow:regression_service",
            "//tensorflow_serving/util:json_tensor",
            "@com_google_absl//absl/strings",
            "@com_google_absl//absl/time",
            "@com_google_absl//absl/types:optional",
            "@com_googlesource_code_re2//:re2",
            "@org_tensorflow//tensorflow/cc/saved_model:loader",
            "@org_tensorflow//tensorflow/cc/saved_model:signature_constants",
            "@org_tensorflow//tensorflow/core:lib",
            "@org_tensorflow//tensorflow/core:protos_all_cc",
        ],
    )
In the HTTP handler's header file, http_rest_api_handler.h, declare the new method and the regex:

    Status GetHealth(string* output);
    ...
    const RE2 health_api_regex_;
Implement it in the corresponding http_rest_api_handler.cc. First include the generated header:

    #include "tensorflow_serving/apis/health.pb.h"
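Initialize health_api_regex_ in the HttpRestApiHandler constructor. Note that the request path is validated before it is handed to the HTTP handler, and that check requires v1 in the path, so the pattern includes v1 as well:

    health_api_regex_(R"((?i)/v1/health)")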
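To see which paths this pattern accepts, here is a minimal sketch in Java, whose regex engine treats this particular pattern the same way RE2 does. HealthPathMatcher is a hypothetical helper written for illustration, not part of TF-Serving; Matcher.matches() stands in for RE2::FullMatch because both require the whole input to match.

```java
import java.util.regex.Pattern;

final class HealthPathMatcher {
    // Same pattern text as health_api_regex_; (?i) makes the match case-insensitive.
    private static final Pattern HEALTH = Pattern.compile("(?i)/v1/health");

    // matches() succeeds only if the entire path matches, like RE2::FullMatch.
    static boolean isHealthPath(String path) {
        return HEALTH.matcher(path).matches();
    }

    public static void main(String[] args) {
        System.out.println(isHealthPath("/v1/health"));       // exact path
        System.out.println(isHealthPath("/V1/Health"));       // different case
        System.out.println(isHealthPath("/v1/health/extra")); // trailing segment
        System.out.println(isHealthPath("/health"));          // missing v1 prefix
    }
}
```

The first two calls match and the last two do not, so a health URI without the v1 prefix would never reach GetHealth.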
Add a branch for health to the main dispatch function, ProcessRequest:

    Status HttpRestApiHandler::ProcessRequest(
        ...
      if (http_method == "POST" &&
          RE2::FullMatch(string(request_path), prediction_api_regex_,
                         &model_name, &model_version_str, &method)) {
        ...
      } else if (http_method == "GET" &&
                 RE2::FullMatch(string(request_path), modelstatus_api_regex_,
                                &model_name, &model_version_str,
                                &model_subresource)) {
        ...
      } else if (http_method == "GET" &&
                 RE2::FullMatch(string(request_path), health_api_regex_)) {
        status = GetHealth(output);
      }

      if (!status.ok()) {
        FillJsonErrorMsg(status.error_message(), output);
      }
      return status;
    }
The GetHealth implementation:

    Status HttpRestApiHandler::GetHealth(string* output) {
      HealthResponse response;
      response.set_status("UP");
      JsonPrintOptions opts;
      opts.add_whitespace = true;
      opts.always_print_primitive_fields = true;
      // Note this is protobuf::util::Status (not TF Status) object.
      const auto& status = MessageToJsonString(response, output, opts);
      if (!status.ok()) {
        return errors::Internal("Failed to convert proto to json. Error: ",
                                status.ToString());
      }
      return Status::OK();
    }
Build and start the server, then curl the health endpoint:

    $ curl http://localhost:8501/v1/health
    {
     "status": "UP"
    }
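This response shape matters because the Spring Cloud sidecar's health URI is expected to mimic a Spring Boot health indicator, that is, a JSON document with a status field. The sketch below shows the check in isolation. HealthBody is a hypothetical helper, and pulling the field out with a regex is a shortcut taken to keep the example dependency-free; real code would use a JSON library.

```java
import java.util.regex.Matcher;
import java.util.regex.Pattern;

final class HealthBody {
    // Captures the string value of the "status" field; whitespace around ':' is allowed.
    private static final Pattern STATUS =
            Pattern.compile("\"status\"\\s*:\\s*\"([^\"]*)\"");

    // Returns the status value, or null if the body has no string "status" field.
    static String statusOf(String json) {
        Matcher m = STATUS.matcher(json);
        return m.find() ? m.group(1) : null;
    }

    // True only when the service reports itself as UP.
    static boolean isUp(String json) {
        return "UP".equals(statusOf(json));
    }

    public static void main(String[] args) {
        // Body as returned by the /v1/health endpoint added above.
        System.out.println(isUp("{\n \"status\": \"UP\"\n}"));
    }
}
```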
Start the server with the model that ships with TF-Serving, ./tensorflow-serving/serving/tensorflow_serving/servables/tensorflow/testdata/saved_model_half_plus_two_cpu, and test the online prediction endpoint:

    $ ./bazel-bin/tensorflow_serving/model_servers/tensorflow_model_server \
        --rest_api_port=8501 \
        --port=8502 \
        --model_name=half_plus_two \
        --model_base_path=./tensorflow-serving/serving/tensorflow_serving/servables/tensorflow/testdata/saved_model_half_plus_two_cpu

    $ curl -d '{"instances": [1.0, 2.0, 5.0]}' -X POST http://localhost:8501/v1/models/half_plus_two:predict
    {
        "predictions": [2.5, 3.0, 4.5]
    }
This confirms that the online prediction endpoint works.
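The returned values are consistent with what the model's name suggests: each input x maps to x / 2 + 2. A minimal sketch that reproduces the arithmetic locally, assuming that is all the model computes (HalfPlusTwo is a hypothetical helper, not part of TF-Serving):

```java
import java.util.Arrays;

final class HalfPlusTwo {
    // Assumed model behaviour: y = 0.5 * x + 2 for every input value.
    static double[] predict(double[] instances) {
        double[] out = new double[instances.length];
        for (int i = 0; i < instances.length; i++) {
            out[i] = 0.5 * instances[i] + 2.0;
        }
        return out;
    }

    public static void main(String[] args) {
        // Same inputs as the curl request above.
        System.out.println(Arrays.toString(predict(new double[] {1.0, 2.0, 5.0})));
        // prints [2.5, 3.0, 4.5]
    }
}
```

Having the expected values at hand makes it easy to tell a routing problem from a model problem when the request later goes through the sidecar.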
The sidecar is set up the same way as in the first post; the only change is that it now fronts TF-Serving instead of Django.
The pom file:

    <?xml version="1.0" encoding="UTF-8"?>
    <project xmlns="http://maven.apache.org/POM/4.0.0"
             xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
             xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
        <modelVersion>4.0.0</modelVersion>
        <parent>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-parent</artifactId>
            <version>2.2.0.RELEASE</version>
            <relativePath/> <!-- lookup parent from repository -->
        </parent>
        <groupId>com.example</groupId>
        <artifactId>cloud</artifactId>
        <version>0.0.1-SNAPSHOT</version>
        <name>cloud</name>
        <description>Demo project for Spring Boot</description>

        <properties>
            <java.version>1.8</java.version>
            <spring-cloud.version>Hoxton.RC1</spring-cloud.version>
        </properties>

        <dependencies>
            <dependency>
                <groupId>org.springframework.cloud</groupId>
                <artifactId>spring-cloud-starter-netflix-eureka-server</artifactId>
            </dependency>
            <dependency>
                <groupId>org.springframework.cloud</groupId>
                <artifactId>spring-cloud-starter-netflix-eureka-client</artifactId>
            </dependency>
            <dependency>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-starter-test</artifactId>
                <scope>test</scope>
                <exclusions>
                    <exclusion>
                        <groupId>org.junit.vintage</groupId>
                        <artifactId>junit-vintage-engine</artifactId>
                    </exclusion>
                </exclusions>
            </dependency>
            <dependency>
                <groupId>org.springframework.cloud</groupId>
                <artifactId>spring-cloud-netflix-sidecar</artifactId>
                <!-- <version>1.2.4.RELEASE</version> (optional: pin a specific version) -->
            </dependency>
        </dependencies>

        <dependencyManagement>
            <dependencies>
                <dependency>
                    <groupId>org.springframework.cloud</groupId>
                    <artifactId>spring-cloud-dependencies</artifactId>
                    <version>${spring-cloud.version}</version>
                    <type>pom</type>
                    <scope>import</scope>
                </dependency>
            </dependencies>
        </dependencyManagement>

        <build>
            <plugins>
                <plugin>
                    <groupId>org.springframework.boot</groupId>
                    <artifactId>spring-boot-maven-plugin</artifactId>
                </plugin>
            </plugins>
        </build>

        <repositories>
            <repository>
                <id>spring-milestones</id>
                <name>Spring Milestones</name>
                <url>https://repo.spring.io/milestone</url>
            </repository>
        </repositories>
    </project>
The application.properties file:

    eureka.client.serviceUrl.defaultZone=http://localhost:8761/eureka/
    ## Port the sidecar itself runs on and registers with Eureka
    server.port=8667
    ## Service name shown in the Eureka registry (in production, it is best to keep this the same as the name of the service the sidecar proxies)
    spring.application.name=tfserving
    ## Port of the non-JVM service that the sidecar watches
    sidecar.port=8501
    ## The non-JVM service must implement this endpoint (here, the /v1/health handler added to TF-Serving)
    sidecar.health-uri=http://localhost:8501/v1/health
    #hystrix.command.default.execution.timeout.enabled: false
    hystrix.metrics.enabled=false
The Application class:

    package com.example.cloud;

    import org.springframework.boot.SpringApplication;
    import org.springframework.boot.autoconfigure.SpringBootApplication;
    import org.springframework.cloud.netflix.sidecar.EnableSidecar;
    import org.springframework.web.bind.annotation.RestController;

    @SpringBootApplication
    @RestController
    @EnableSidecar
    public class CloudApplication {

        public static void main(String[] args) {
            SpringApplication.run(CloudApplication.class, args);
        }
    }
Once it is up, the service shows as registered in Eureka.
We reuse the client from the earlier post and add a call to the TF-Serving endpoint.
Add a request class, PredictRequestJson:

    package com.example.callpython;

    import java.io.Serializable;
    import java.util.List;

    public class PredictRequestJson<T> implements Serializable {

        private List<T> instances;
        private String signature_name;

        public List<T> getInstances() {
            return instances;
        }

        public void setInstances(List<T> instances) {
            this.instances = instances;
        }

        public String getSignature_name() {
            return signature_name;
        }

        public void setSignature_name(String signature_name) {
            this.signature_name = signature_name;
        }
    }
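The field names are snake_case because they have to line up with the keys of TF-Serving's REST predict request, instances and signature_name. The sketch below hand-builds the body that such an object is meant to produce, so the wire format is visible without a JSON library. PredictBodySketch is a hypothetical helper for illustration; the actual application leaves serialization to Spring's message converters, whose whitespace may differ.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

final class PredictBodySketch {
    // Builds a predict request body by hand.
    // Assumption: signatureName needs no JSON escaping (true for "serving_default").
    static String toJson(List<Double> instances, String signatureName) {
        String values = instances.stream()
                .map(d -> Double.toString(d))
                .collect(Collectors.joining(", "));
        return "{\"instances\": [" + values + "], "
                + "\"signature_name\": \"" + signatureName + "\"}";
    }

    public static void main(String[] args) {
        System.out.println(toJson(Arrays.asList(1.0, 2.0, 5.1), "serving_default"));
        // prints {"instances": [1.0, 2.0, 5.1], "signature_name": "serving_default"}
    }
}
```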
Add a Feign interface that uses the request class just defined:

    package com.example.callpython;

    import org.springframework.cloud.openfeign.FeignClient;
    import org.springframework.validation.annotation.Validated;
    import org.springframework.web.bind.annotation.RequestBody;
    import org.springframework.web.bind.annotation.RequestMapping;
    import org.springframework.web.bind.annotation.RequestMethod;

    @FeignClient(name = "tfserving")
    public interface TFServingFeign {

        @RequestMapping(value = "/v1/models/half_plus_two:predict", method = RequestMethod.POST)
        String getPredictResult(@Validated @RequestBody PredictRequestJson requestJson) throws Exception;
    }
Add a controller method that calls Feign with hard-coded test data:

    @RequestMapping("tfserving")
    public String requestTFServing() {
        try {
            // The instances are doubles, so parameterize the request accordingly.
            PredictRequestJson<Double> requestJson = new PredictRequestJson<>();
            List<Double> instances = new ArrayList<>();
            instances.add(1.0);
            instances.add(2.0);
            instances.add(5.1);
            requestJson.setInstances(instances);
            requestJson.setSignature_name("serving_default");
            return tfServingFeign.getPredictResult(requestJson);
        } catch (Exception e) {
            System.out.println(e.getMessage());
        }
        return "exception or timeout";
    }
After starting the service, request http://localhost:8700/tfserving, which returns:
{ predictions: [ 2.5, 3, 4.55 ] }