公司大部分应用使用的是Java开发,要想直接使用Python训练的模型非常困难。在网上搜索后发现,可以先将生成的模型转换为PMML文件,然后即可在Java中直接调用。
以LightGBM为例:
1、将生成的模型导出为txt格式
import pandas as pd
from lightgbm import LGBMClassifier

# Read the iris CSV; columns 0-3 are used as features, column 4 as the label.
iris = pd.read_csv("xml/iris.csv")
features = iris.iloc[:, 0:4].values
labels = iris.iloc[:, 4].values

# Multiclass gradient-boosted classifier with a fixed seed.
clf = LGBMClassifier(
    boosting_type='gbdt',
    objective="multiclass",
    nthread=8,
    seed=42,
)
# NOTE(review): this attribute is set by hand before fit(); it presumably has
# no effect on training (the class count should come from the labels) — confirm.
clf.n_classes = 3

# Name the features after every CSV column except the last (the label column),
# so that the exported model carries the column names.
feature_names = iris.columns.tolist()[:-1]
clf.fit(features, labels, feature_name=feature_names)

# Dump the underlying booster in LightGBM's text format for the PMML converter.
clf.booster_.save_model("xml/lightgbm.txt")
2、 使用工具将txt模型转化为pmml格式
# Convert the LightGBM text model (lightgbm.txt) into a PMML file (lightgbm.pmml).
# NOTE(review): step 1 saved the model as xml/lightgbm.txt — run this from that
# directory or adjust the --lgbm-input path accordingly.
java -jar converter-executable-1.2-SNAPSHOT.jar --lgbm-input lightgbm.txt --pmml-output lightgbm.pmml
3、 在JAVA代码中直接调用
备注:调用前需要引入如下jar包: https://github.com/jpmml/jpmml-evaluator ,示例代码:
package com.pmmldemo.test;

import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.dmg.pmml.FieldName;
import org.dmg.pmml.PMML;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.FieldValue;
import org.jpmml.evaluator.InputField;
import org.jpmml.evaluator.ModelEvaluator;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.TargetField;

/**
 * Demo that loads a PMML model file and evaluates a single record with the
 * JPMML evaluator, printing every target field of the result.
 */
public class PMMLPrediction {

    /**
     * Builds one sample record keyed by feature name and scores it against
     * {@code lightgbm.pmml}.
     *
     * @param args unused
     * @throws Exception if loading or evaluating the model fails
     */
    public static void main(String[] args) throws Exception {
        // Assemble the model inputs, keyed by feature name.
        Map<String, Double> sample = new HashMap<String, Double>();
        sample.put("sepal_length", 5.1);
        sample.put("sepal_width", 3.5);
        sample.put("petal_length", 1.4);
        sample.put("petal_width", 0.2);

        predictLrHeart(sample, "lightgbm.pmml");
    }

    /**
     * Loads the PMML file at {@code pathxml}, evaluates it with the values in
     * {@code irismap}, and prints each target field's name and value to
     * standard error.
     *
     * @param irismap raw input values keyed by the model's input field names
     * @param pathxml path of the PMML file to load
     * @throws Exception if the file cannot be read or the evaluation fails
     */
    public static void predictLrHeart(Map<String, Double> irismap, String pathxml) throws Exception {
        // Load the model; the stream stays open for the whole evaluation, as before.
        try (InputStream in = new FileInputStream(new File(pathxml))) {
            PMML model = org.jpmml.model.PMMLUtil.unmarshal(in);

            ModelEvaluator<?> modelEvaluator = ModelEvaluatorFactory.newInstance().newModelEvaluator(model);
            Evaluator evaluator = modelEvaluator;

            Map<FieldName, FieldValue> arguments = prepareArguments(evaluator, irismap);
            Map<FieldName, ?> results = evaluator.evaluate(arguments);

            printTargets(evaluator, results);
        }
    }

    /**
     * Walks the model's input fields and converts the matching raw value from
     * {@code rawValues} into a prepared field value, preserving field order.
     */
    private static Map<FieldName, FieldValue> prepareArguments(Evaluator evaluator,
            Map<String, Double> rawValues) {
        Map<FieldName, FieldValue> prepared = new LinkedHashMap<>();
        List<InputField> fields = evaluator.getInputFields();
        for (InputField field : fields) {
            FieldName name = field.getName();
            // A feature absent from the map yields null here and is passed on as-is.
            Object raw = rawValues.get(name.getValue());
            prepared.put(name, field.prepare(raw));
        }
        return prepared;
    }

    /**
     * Prints every target field of the evaluation result; there can be more
     * than one output, e.g. for classification problems.
     */
    private static void printTargets(Evaluator evaluator, Map<FieldName, ?> results) {
        List<TargetField> targets = evaluator.getTargetFields();
        for (TargetField target : targets) {
            FieldName name = target.getName();
            Object value = results.get(name);
            System.err.println("target: " + name.getValue() + " value: " + value);
        }
    }
}
常见模型转化方法: