Original article: Predicting Loan Credit Risk using Apache Spark Machine Learning Random Forests
Author: Carol McDonald, MapR Solutions Architect
Translator: KK4SBB
Editor: 周建丁 (zhoujd@csdn.net)
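In this post, I'll show how to use the random forest algorithm from Apache Spark's spark.ml library to classify the credit risk of bank loans. Spark's spark.ml library is built on DataFrames and provides a set of APIs that help users create and tune machine learning pipelines. When spark.ml is used together with DataFrames, Spark can apply intelligent optimizations, which improves performance.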
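Classification is a family of supervised machine learning algorithms. They predict which category an item belongs to (for example, whether a transaction is fraudulent) by learning from labeled examples of known items (for example, transactions already known to be fraudulent or not). Classification takes a labeled data set with predetermined features and learns from that information how to label new records. You can think of the features as the questions you ask about an item, and the label as the answer to those questions. In the example below, if an animal walks, swims, and quacks like a duck, it gets the label "duck".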
Let's look at an example of credit risk for bank loans:
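A decision tree is a classification model that predicts a class or label from a set of input features. At each node, it evaluates an expression involving one feature and uses the result to choose a branch to the next node. The figure below shows a possible decision tree for predicting credit risk. Each decision question is a node in the tree, and the "yes" or "no" answers are the branches leading to the child nodes.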
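The node-and-branch evaluation can be sketched in plain Scala, without Spark, to show how a trained tree classifies one applicant. The tree shape, feature names, and thresholds below are hypothetical and chosen only for illustration. They are not the model learned later in this post.

```scala
object DecisionTreeSketch {
  // Internal nodes test one feature against a threshold; leaves hold the
  // predicted label (1.0 = creditable, 0.0 = not creditable).
  sealed trait Node
  final case class Leaf(prediction: Double) extends Node
  final case class Split(feature: String, threshold: Double,
                         ifLessOrEqual: Node, otherwise: Node) extends Node

  // Start at the root, evaluate the node's test, follow the matching branch,
  // and stop when a leaf is reached.
  @annotation.tailrec
  def predict(node: Node, applicant: Map[String, Double]): Double = node match {
    case Leaf(p) => p
    case Split(feature, threshold, left, right) =>
      if (applicant(feature) <= threshold) predict(left, applicant)
      else predict(right, applicant)
  }

  // Hypothetical tree with made-up thresholds, for illustration only.
  val tree: Node =
    Split("balance", 1.0,
      Split("duration", 24.0, Leaf(1.0), Leaf(0.0)),
      Leaf(1.0))
}
```

For example, `DecisionTreeSketch.predict(DecisionTreeSketch.tree, Map("balance" -> 0.0, "duration" -> 36.0))` follows the left branch and then the right branch, and returns 0.0. Spark's trees use the same kind of `<=` threshold test on continuous features, which is why the `toDebugString` output later in this post reads `If (feature 0 <= 1.0)`.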
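Ensemble learning algorithms combine multiple machine learning algorithms to obtain a better model. Random forest is a popular ensemble learning method for classification and regression. It builds a model from multiple decision trees, each trained on a different subset of the training data. The forest's prediction combines the outputs of all the trees, which reduces variance and improves prediction accuracy. In random forest classification, each tree's prediction counts as one vote, and the class that receives the most votes is the predicted class.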
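The voting step described above fits in a few lines of plain Scala. This sketch only illustrates simple majority voting. It assumes each tree has already produced a label, and the per-tree predictions used to exercise it are made up.

```scala
object MajorityVoteSketch {
  // Count the votes for each label and return the label with the most votes.
  // Ties go to the smaller label value so the result is deterministic.
  def vote(treePredictions: Seq[Double]): Double = {
    require(treePredictions.nonEmpty, "need at least one tree prediction")
    val counts: Map[Double, Int] =
      treePredictions.groupBy(identity).map { case (label, votes) => (label, votes.size) }
    counts.toSeq.sortBy { case (label, count) => (-count, label) }.head._1
  }
}
```

With five trees voting `Seq(1.0, 0.0, 1.0, 1.0, 0.0)`, label 1.0 wins three votes to two. The model trained later in this post uses 20 trees (`setNumTrees(20)`).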
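We use the German Credit Data Set, which classifies people, described by a set of attributes, as good or bad credit risks. For each bank loan applicant, we have the following information:
The csv file containing the German credit data has the following format: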
1,1,18,4,2,1049,1,2,4,2,1,4,2,21,3,1,1,3,1,1,1
1,1,9,4,0,2799,1,3,2,3,1,2,1,36,3,1,2,3,2,1,1
1,2,12,2,9,841,2,4,2,2,1,4,1,23,3,1,1,2,1,1,1
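In this scenario, we build a random forest model of decision trees to predict the label/class (creditable or not) from the following features:
This tutorial uses Spark 1.6.1.
Following the tutorial instructions, log in to the MapR Sandbox with username user01 and password mapr. Copy the sample data file to your sandbox home directory /user/user01 using scp. (Note: you may need to update Spark to a newer version first.) Then start the Spark shell: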
$ spark-shell --master local[1]
First, we import the machine learning packages.
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.VectorAssembler
import sqlContext.implicits._
import sqlContext._
import org.apache.spark.ml.tuning.{ ParamGridBuilder, CrossValidator }
import org.apache.spark.ml.{ Pipeline, PipelineStage }
// needed for the RDD type used in parseRDD below
import org.apache.spark.rdd.RDD
We use a Scala case class to define the Credit attributes. Each instance corresponds to one line of the csv file.
// define the Credit Schema
case class Credit(
  creditability: Double,
  balance: Double, duration: Double, history: Double, purpose: Double, amount: Double,
  savings: Double, employment: Double, instPercent: Double, sexMarried: Double, guarantors: Double,
  residenceDuration: Double, assets: Double, age: Double, concCredit: Double, apartment: Double,
  credits: Double, occupation: Double, dependents: Double, hasPhone: Double, foreign: Double
)
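The functions below parse a line of the data file and store the values in the Credit class. For the categorical fields, 1 is subtracted from each value so that the category indices start at 0.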
// function to create a Credit class from an Array of Double
def parseCredit(line: Array[Double]): Credit = {
  Credit(
    line(0),
    line(1) - 1, line(2), line(3), line(4), line(5),
    line(6) - 1, line(7) - 1, line(8), line(9) - 1, line(10) - 1,
    line(11) - 1, line(12) - 1, line(13), line(14) - 1, line(15) - 1,
    line(16) - 1, line(17) - 1, line(18) - 1, line(19) - 1, line(20) - 1
  )
}
// function to transform an RDD of Strings into an RDD of Array[Double]
def parseRDD(rdd: RDD[String]): RDD[Array[Double]] = {
  rdd.map(_.split(",")).map(_.map(_.toDouble))
}
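The zero-based shifting can be checked without starting Spark by applying the same split-and-subtract logic in plain Scala to a sample record from the csv file. This is a verification sketch only. `ParseCheckSketch.parseLine` is a hypothetical helper that mirrors `parseCredit` but returns an array instead of a `Credit` object.

```scala
object ParseCheckSketch {
  // Positions left unchanged by parseCredit: creditability (0), duration (2),
  // history (3), purpose (4), amount (5), instPercent (8) and age (13).
  // Every other position is decremented by 1.
  val unshifted: Set[Int] = Set(0, 2, 3, 4, 5, 8, 13)

  def parseLine(line: String): Array[Double] =
    line.split(",").map(_.trim.toDouble).zipWithIndex.map {
      case (value, i) => if (unshifted.contains(i)) value else value - 1
    }
}
```

Parsing the first record, `1,1,18,4,2,1049,1,2,4,2,1,4,2,21,3,1,1,3,1,1,1`, gives 1.0, 0.0, 18.0, 4.0, 2.0, 1049.0, 0.0, 1.0, 4.0, 1.0, 0.0, 3.0, 1.0, 21.0, 2.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0. This matches the first row of the `creditDF.show` output below.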
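Next, we load the data from the germancredit.csv file into an RDD of Strings. We pass that RDD to the parseRDD function, which maps each String to an Array[Double]. A second map() then applies the parseCredit function to turn each Array[Double] into a Credit object. Finally, toDF() converts the resulting RDD[Credit] into a DataFrame whose schema comes from the Credit class. cache() keeps the DataFrame in memory, and registerTempTable makes it available to SQL queries as the table credit.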
// load the data into an RDD, parse it into Credit objects and convert it to a DataFrame
val creditDF = parseRDD(sc.textFile("germancredit.csv")).map(parseCredit).toDF().cache()
creditDF.registerTempTable("credit")
The DataFrame printSchema() method prints the schema to the console in a tree format.
// Return the schema of this DataFrame
creditDF.printSchema
root
 |-- creditability: double (nullable = false)
 |-- balance: double (nullable = false)
 |-- duration: double (nullable = false)
 |-- history: double (nullable = false)
 |-- purpose: double (nullable = false)
 |-- amount: double (nullable = false)
 |-- savings: double (nullable = false)
 |-- employment: double (nullable = false)
 |-- instPercent: double (nullable = false)
 |-- sexMarried: double (nullable = false)
 |-- guarantors: double (nullable = false)
 |-- residenceDuration: double (nullable = false)
 |-- assets: double (nullable = false)
 |-- age: double (nullable = false)
 |-- concCredit: double (nullable = false)
 |-- apartment: double (nullable = false)
 |-- credits: double (nullable = false)
 |-- occupation: double (nullable = false)
 |-- dependents: double (nullable = false)
 |-- hasPhone: double (nullable = false)
 |-- foreign: double (nullable = false)
// Display the top 20 rows of DataFrame
creditDF.show
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+
|creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+
|          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|
|          1.0|    0.0|     9.0|    4.0|    0.0|2799.0|    0.0|       2.0|        2.0|       2.0|       0.0|              1.0|   0.0|36.0|       2.0|      0.0|    1.0|       2.0|       1.0|     0.0|    0.0|
|          1.0|    1.0|    12.0|    2.0|    9.0| 841.0|    1.0|       3.0|        2.0|       1.0|       0.0|              3.0|   0.0|23.0|       2.0|      0.0|    0.0|       1.0|       0.0|     0.0|    0.0|
|          1.0|    0.0|    12.0|    4.0|    0.0|2122.0|    0.0|       2.0|        3.0|       2.0|       0.0|              1.0|   0.0|39.0|       2.0|      0.0|    1.0|       1.0|       1.0|     0.0|    1.0|
|          1.0|    0.0|    12.0|    4.0|    0.0|2171.0|    0.0|       2.0|        4.0|       2.0|       0.0|              3.0|   1.0|38.0|       0.0|      1.0|    1.0|       1.0|       0.0|     0.0|    1.0|
|          1.0|    0.0|    10.0|    4.0|    0.0|2241.0|    0.0|       1.0|        1.0|       2.0|       0.0|              2.0|   0.0|48.0|       2.0|      0.0|    1.0|       1.0|       1.0|     0.0|    1.0|
|          1.0|    0.0|     8.0|    4.0|    0.0|3398.0|    0.0|       3.0|        1.0|       2.0|       0.0|              3.0|   0.0|39.0|       2.0|      1.0|    1.0|       1.0|       0.0|     0.0|    1.0|
|          1.0|    0.0|     6.0|    4.0|    0.0|1361.0|    0.0|       1.0|        2.0|       2.0|       0.0|              3.0|   0.0|40.0|       2.0|      1.0|    0.0|       1.0|       1.0|     0.0|    1.0|
|          1.0|    3.0|    18.0|    4.0|    3.0|1098.0|    0.0|       0.0|        4.0|       1.0|       0.0|              3.0|   2.0|65.0|       2.0|      1.0|    1.0|       0.0|       0.0|     0.0|    0.0|
|          1.0|    1.0|    24.0|    2.0|    3.0|3758.0|    2.0|       0.0|        1.0|       1.0|       0.0|              3.0|   3.0|23.0|       2.0|      0.0|    0.0|       0.0|       0.0|     0.0|    0.0|
|          1.0|    0.0|    11.0|    4.0|    0.0|3905.0|    0.0|       2.0|        2.0|       2.0|       0.0|              1.0|   0.0|36.0|       2.0|      0.0|    1.0|       2.0|       1.0|     0.0|    0.0|
|          1.0|    0.0|    30.0|    4.0|    1.0|6187.0|    1.0|       3.0|        1.0|       3.0|       0.0|              3.0|   2.0|24.0|       2.0|      0.0|    1.0|       2.0|       0.0|     0.0|    0.0|
|          1.0|    0.0|     6.0|    4.0|    3.0|1957.0|    0.0|       3.0|        1.0|       1.0|       0.0|              3.0|   2.0|31.0|       2.0|      1.0|    0.0|       2.0|       0.0|     0.0|    0.0|
|          1.0|    1.0|    48.0|    3.0|   10.0|7582.0|    1.0|       0.0|        2.0|       2.0|       0.0|              3.0|   3.0|31.0|       2.0|      1.0|    0.0|       3.0|       0.0|     1.0|    0.0|
|          1.0|    0.0|    18.0|    2.0|    3.0|1936.0|    4.0|       3.0|        2.0|       3.0|       0.0|              3.0|   2.0|23.0|       2.0|      0.0|    1.0|       1.0|       0.0|     0.0|    0.0|
|          1.0|    0.0|     6.0|    2.0|    3.0|2647.0|    2.0|       2.0|        2.0|       2.0|       0.0|              2.0|   0.0|44.0|       2.0|      0.0|    0.0|       2.0|       1.0|     0.0|    0.0|
|          1.0|    0.0|    11.0|    4.0|    0.0|3939.0|    0.0|       2.0|        1.0|       2.0|       0.0|              1.0|   0.0|40.0|       2.0|      1.0|    1.0|       1.0|       1.0|     0.0|    0.0|
|          1.0|    1.0|    18.0|    2.0|    3.0|3213.0|    2.0|       1.0|        1.0|       3.0|       0.0|              2.0|   0.0|25.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|
|          1.0|    1.0|    36.0|    4.0|    3.0|2337.0|    0.0|       4.0|        4.0|       2.0|       0.0|              3.0|   0.0|36.0|       2.0|      1.0|    0.0|       2.0|       0.0|     0.0|    0.0|
|          1.0|    3.0|    11.0|    4.0|    0.0|7228.0|    0.0|       2.0|        1.0|       2.0|       0.0|              3.0|   1.0|39.0|       2.0|      1.0|    1.0|       1.0|       0.0|     0.0|    0.0|
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+
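Once the DataFrame is initialized, you can query the data with SQL. Here are some example queries using the Scala DataFrame API:
describe computes statistics for numeric columns, including count, mean, standard deviation, min, and max.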
// computes statistics for balance
creditDF.describe("balance").show
+-------+-----------------+
|summary|          balance|
+-------+-----------------+
|  count|             1000|
|   mean|            1.577|
| stddev|1.257637727110893|
|    min|              0.0|
|    max|              3.0|
+-------+-----------------+

// compute the avg balance by creditability (the label)
creditDF.groupBy("creditability").avg("balance").show
+-------------+------------------+
|creditability|      avg(balance)|
+-------------+------------------+
|          1.0|1.8657142857142857|
|          0.0|0.9033333333333333|
+-------------+------------------+
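The plain-Scala sketch below computes the same five statistics that `describe` reports, on a small in-memory array, as an aid to intuition. It assumes the reported stddev is the sample standard deviation (squared deviations divided by n - 1), which is how Spark's `stddev` aggregate is defined from Spark 1.6 on. The input values are made up.

```scala
object DescribeSketch {
  final case class Summary(count: Long, mean: Double, stddev: Double, min: Double, max: Double)

  def describe(xs: Seq[Double]): Summary = {
    require(xs.size > 1, "sample standard deviation needs at least two values")
    val n = xs.size
    val mean = xs.sum / n
    // Sample variance: sum of squared deviations divided by (n - 1).
    val variance = xs.map(x => (x - mean) * (x - mean)).sum / (n - 1)
    Summary(n.toLong, mean, math.sqrt(variance), xs.reduce(_ min _), xs.reduce(_ max _))
  }
}
```

For `Seq(0.0, 1.0, 3.0, 3.0, 1.0)`, the count is 5 and the mean is 1.6. The squared deviations sum to 7.2, so the sample variance is 7.2 / 4 = 1.8 and the stddev is about 1.342.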
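You can register a DataFrame as a temporary table under a given name and then run SQL statements against it with the sql method provided by SQLContext. Here is an example query using sqlContext: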
sqlContext.sql("SELECT creditability, avg(balance) as avgbalance, avg(amount) as avgamt, avg(duration) as avgdur FROM credit GROUP BY creditability").show
+-------------+------------------+------------------+------------------+
|creditability|        avgbalance|            avgamt|            avgdur|
+-------------+------------------+------------------+------------------+
|          1.0|1.8657142857142857| 2985.442857142857|19.207142857142856|
|          0.0|0.9033333333333333|3938.1266666666666|             24.86|
+-------------+------------------+------------------+------------------+
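The `GROUP BY creditability` query computes one set of averages per label value. Here is that computation as a plain-Scala sketch. The `Row` case class and the sample rows used to exercise it are hypothetical and do not come from the data set.

```scala
object GroupedAverageSketch {
  final case class Row(creditability: Double, balance: Double, amount: Double, duration: Double)

  // For each label value, average balance, amount and duration, like
  // SELECT creditability, avg(balance), avg(amount), avg(duration) ... GROUP BY creditability
  def averagesByLabel(rows: Seq[Row]): Map[Double, (Double, Double, Double)] =
    rows.groupBy(_.creditability).map { case (label, group) =>
      val n = group.size.toDouble
      (label, (group.map(_.balance).sum / n,
               group.map(_.amount).sum / n,
               group.map(_.duration).sum / n))
    }
}
```

Each key of the returned map plays the role of one output row of the SQL query above.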
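To build a classifier model, you first extract the features that contribute most to the classification. In the German credit data set, each sample is labeled with one of two classes: 1 (creditable) and 0 (not creditable).
The features for each sample consist of the following fields:
In the code below, a VectorAssembler combines all of the feature columns into a single vector column and returns a new DataFrame.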
// define the feature columns to put in the feature vector
val featureCols = Array("balance", "duration", "history", "purpose", "amount",
  "savings", "employment", "instPercent", "sexMarried", "guarantors",
  "residenceDuration", "assets", "age", "concCredit", "apartment",
  "credits", "occupation", "dependents", "hasPhone", "foreign")
// set the input and output column names
val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
// return a dataframe with all of the feature columns in a vector column
val df2 = assembler.transform(creditDF)
// the transform method produced a new column: features.
df2.show
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+
|creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+
|          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|(20,[1,2,3,4,6,7,...|
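The `features` value in the first row, `(20,[1,2,3,4,6,7,...`, is a sparse vector printed as (size, [indices of non-zero entries], [values]). VectorAssembler concatenates the 20 columns in `featureCols` order, and Spark stored this row in sparse form. The plain-Scala sketch below is not Spark code (`toSparse` is a hypothetical helper). It rebuilds that representation from the first row's feature values to show where the index list comes from.

```scala
object SparseVectorSketch {
  // Feature values of the first row, in featureCols order
  // (balance, duration, history, ..., hasPhone, foreign).
  val firstRowFeatures: Array[Double] = Array(
    0.0, 18.0, 4.0, 2.0, 1049.0, 0.0, 1.0, 4.0, 1.0, 0.0,
    3.0, 1.0, 21.0, 2.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0)

  // Keep only the non-zero entries, remembering their positions.
  def toSparse(dense: Array[Double]): (Int, Array[Int], Array[Double]) = {
    val nonZero = dense.zipWithIndex.filter { case (v, _) => v != 0.0 }
    (dense.length, nonZero.map(_._2), nonZero.map(_._1))
  }
}
```

The non-zero positions are 1, 2, 3, 4, 6, 7, 8, 10, 11, 12, 13, 16. Their prefix is the `[1,2,3,4,6,7,...` shown in the output. Position 0 (balance) and position 5 (savings) are absent because both values are 0.0 in this row.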
Next, we use a StringIndexer to return a DataFrame with a label column derived from the creditability column.
// Create a label column with the StringIndexer
val labelIndexer = new StringIndexer().setInputCol("creditability").setOutputCol("label")
val df3 = labelIndexer.fit(df2).transform(df2)
// the transform method produced a new column: label.
df3.show
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+
|creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|label|
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+
|          1.0|    0.0|    18.0|    4.0|    2.0|1049.0|    0.0|       1.0|        4.0|       1.0|       0.0|              3.0|   1.0|21.0|       2.0|      0.0|    0.0|       2.0|       0.0|     0.0|    0.0|(20,[1,2,3,4,6,7,...|  0.0|
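In the output above, the first row has creditability 1.0 but label 0.0. This happens because StringIndexer orders the distinct values by descending frequency and gives index 0.0 to the most frequent one. In the German credit data, 1.0 (good credit) is the majority class, with 700 of 1000 rows. StringIndexer also works on the string form of the input values.

Below is a plain-Scala sketch of that ordering on a local Seq. It is not Spark's implementation. The alphabetical tie-breaking rule is an assumption of the sketch.

```scala
object LabelIndexSketch {
  // Map each distinct value to an index by descending frequency:
  // the most frequent value gets 0.0, the next 1.0, and so on.
  // Ties are broken alphabetically here; that tie rule is an assumption of
  // this sketch, not something Spark guarantees.
  def fitIndex(values: Seq[String]): Map[String, Double] =
    values
      .groupBy(identity)
      .map { case (value, occurrences) => (value, occurrences.size) }
      .toSeq
      .sortBy { case (value, count) => (-count, value) }
      .zipWithIndex
      .map { case ((value, _), index) => (value, index.toDouble) }
      .toMap

  def main(args: Array[String]): Unit = {
    // Toy creditability column (string form): three "good" rows and one "bad" row.
    val creditability = Seq("1.0", "1.0", "0.0", "1.0")
    val index = fitIndex(creditability)
    println(s"1.0 -> ${index("1.0")}, 0.0 -> ${index("0.0")}") // prints 1.0 -> 0.0, 0.0 -> 1.0
  }
}
```

A consequence for the rest of the tutorial: label 1.0 corresponds to creditability 0.0, the bad-credit class.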
Below, the data is split into a training set and a test set: 70% of the data is used to train the model, and 30% is used to test it.
// split the dataframe into training and test data
val splitSeed = 5043
val Array(trainingData, testData) = df3.randomSplit(Array(0.7, 0.3), splitSeed)
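randomSplit assigns rows at random under a fixed seed. As I understand it, the 70/30 proportions are therefore approximate rather than exact, and reusing the seed makes the split reproducible. (In Spark the result may also depend on how the data is partitioned.)

The plain-Scala sketch below imitates that idea on a local Seq: each row independently draws a uniform number and goes to the split whose cumulative weight range contains it. It is an illustration under those assumptions, not Spark's code, and it will not reproduce Spark's actual split.

```scala
import scala.util.Random

object RandomSplitSketch {
  // Send each row to one split independently: draw u uniformly from [0, 1) and
  // pick the first split whose cumulative (normalized) weight exceeds u.
  def randomSplit[T](rows: Seq[T], weights: Array[Double], seed: Long): Array[Seq[T]] = {
    val total = weights.sum
    val upperBounds = weights.map(_ / total).scanLeft(0.0)(_ + _).tail
    val rng = new Random(seed)
    val assigned = rows.map { row =>
      val u = rng.nextDouble()
      val i = upperBounds.indexWhere(u < _)
      // guard against floating-point round-off leaving the last bound slightly below 1.0
      (if (i == -1) weights.length - 1 else i, row)
    }
    Array.tabulate(weights.length)(split => assigned.filter(_._1 == split).map(_._2))
  }

  def main(args: Array[String]): Unit = {
    val Array(training, test) = randomSplit(1 to 1000, Array(0.7, 0.3), 5043L)
    println(s"training=${training.size} test=${test.size}") // roughly 700 / 300, not exact
  }
}
```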
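Next, we train a random forest classifier with the following parameters:

- Gini impurity as the split criterion
- a maximum tree depth of 3
- 20 trees
- the "auto" feature subset strategy (for a forest classifier, this considers the square root of the number of features as split candidates at each node)
- a random seed of 5043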
Training the model means associating the input features with the known labels of those samples.
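With setImpurity("gini"), each tree chooses splits that reduce Gini impurity, defined as 1 minus the sum of the squared class proportions. The plain-Scala sketch below shows that criterion for a single feature.

It simplifies in two ways. It tries every distinct feature value as a threshold, whereas Spark evaluates a limited set of candidate thresholds controlled by maxBins. A random forest also restricts which features are considered at each node (featureSubsetStrategy). So this is an illustration of the criterion, not Spark's implementation.

```scala
object GiniSketch {
  // Gini impurity of a set of labels: 1 - sum over classes k of p_k^2,
  // where p_k is the fraction of labels equal to class k.
  def gini(labels: Seq[Double]): Double =
    if (labels.isEmpty) 0.0
    else {
      val n = labels.size.toDouble
      1.0 - labels.groupBy(identity).values.map(group => math.pow(group.size / n, 2)).sum
    }

  // Impurity decrease of the split "feature <= threshold" on (featureValue, label) rows:
  // parent impurity minus the size-weighted impurities of the two children.
  def giniGain(rows: Seq[(Double, Double)], threshold: Double): Double = {
    val (left, right) = rows.partition(_._1 <= threshold)
    val n = rows.size.toDouble
    gini(rows.map(_._2)) -
      (left.size / n) * gini(left.map(_._2)) -
      (right.size / n) * gini(right.map(_._2))
  }

  // Try every distinct feature value as a threshold and keep the best one.
  def bestThreshold(rows: Seq[(Double, Double)]): Double =
    rows.map(_._1).distinct.sorted.maxBy(t => giniGain(rows, t))

  def main(args: Array[String]): Unit = {
    // Toy (featureValue, label) rows: low values -> label 0.0, high values -> label 1.0
    val rows = Seq((0.0, 0.0), (1.0, 0.0), (2.0, 1.0), (3.0, 1.0))
    println(s"best split: feature <= ${bestThreshold(rows)}") // best split: feature <= 1.0
  }
}
```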
// create the classifier, set parameters for training
val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(5043)
// use the random forest classifier to train (fit) the model
val model = classifier.fit(trainingData)
// print out the random forest trees
model.toDebugString
res5: String = "RandomForestClassificationModel (uid=rfc_6c4ceb92ba78) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 0 <= 1.0)
     If (feature 10 <= 0.0)
      If (feature 3 <= 6.0)
       Predict: 0.0
      Else (feature 3 > 6.0)
       Predict: 0.0
     Else (feature 10 > 0.0)
      If (feature 12 <= 63.0)
       Predict: 0.0
      Else (feature 12 > 63.0)
       Predict: 0.0
    Else (feature 0 > 1.0)
     If (feature 13 <= 1.0)
      If (feature 3 <= 3.0)
       Predict: 0.0
      Else (feature 3 > 3.0)
       Predict: 1.0
     Else (feature 13 > 1.0)
      If (feature 7 <= 1.0)
       Predict: 0.0
      Else (feature 7 > 1.0)
       Predict: 0.0
  Tree 1 (weight 1.0):
    If (feature 2 <= 1.0)
     If (feature 15 <= 0.0)
      If (feature 11 <= 0.0)
       Predict: 0.0
      Else (feature 11 > 0.0)
       Predict: 1.0
     Else (feature 15 > 0.0)
      If (feature 11 <= 0.0)
       Predict: 0.0
      Else (feature 11 > 0.0)
       Predict: 1.0
    Else (feature 2 > 1.0)
     If (feature 12 <= 31.0)
      If (feature 5 <= 0.0)
       Predict: 0.0
      Else (feature 5 > 0.0)
       Predict: 0.0
     Else (feature 12 > 31.0)
      If (feature 3 <= 4.0)
       Predict: 0.0
      Else (feature 3 > 4.0)
       Predict: 0.0
  Tree 2 (weight 1.0):
    If (feature 8 <= 1.0)
     If (feature 6 <= 2.0)
      If (feature 4 <= 10875.0)
       Predict: 0.0
      Else (feature 4 > 10875.0)
       Predict: 1.0
     Else (feature 6 > 2.0)
      If (feature 1 <= 36.0)
       Predict: 0.0
      Else (feature 1 > 36.0)
       Predict: 1.0
    Else (feature 8 > 1.0)
     If (feature 5 <= 0.0)
      If (feature 4 <= 4113.0)
       Predict: 0.0
      Else (feature 4 > 4113.0)
       Predict: 1.0
     Else (feature 5 > 0.0)
      If (feature 11 <= 2.0)
       Predict: 0.0
      Else (feature 11 > 2.0)
       Predict: 0.0
  Tree 3 ...
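The tree dump refers to features only by their index in the assembled features vector. The sketch below assumes the features column was built from the 20 non-label fields of the Credit case class, in declaration order. That assumption matches the output: feature 12 splits at 63.0, which looks like age, and feature 4 splits at 10875.0, which looks like amount. If the earlier featureCols array used a different order, adjust the list.

The helper (a hypothetical name, not part of Spark) rewrites "feature N" into the column name.

```scala
import scala.util.matching.Regex

object TreeDebugNames {
  // Assumption: the 20 feature columns in the same order as the Credit case class
  // (every field except the creditability label).
  val featureCols: Array[String] = Array(
    "balance", "duration", "history", "purpose", "amount",
    "savings", "employment", "instPercent", "sexMarried", "guarantors",
    "residenceDuration", "assets", "age", "concCredit", "apartment",
    "credits", "occupation", "dependents", "hasPhone", "foreign")

  private val FeatureRef: Regex = """feature (\d+)""".r

  // Replace every "feature N" in a toDebugString dump with names(N).
  def withNames(debug: String, names: Array[String]): String =
    FeatureRef.replaceAllIn(debug, m => Regex.quoteReplacement(names(m.group(1).toInt)))
}
```

In the spark-shell session, pasting the object and calling println(TreeDebugNames.withNames(model.toDebugString, TreeDebugNames.featureCols)) should print the same trees with readable split conditions.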
Next, we run the model on the test data to get predictions.
// run the model on test features to get predictions
val predictions = model.transform(testData)
// As you can see, the model transform produced new columns: rawPrediction, probability and prediction.
predictions.show
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
|creditability|balance|duration|history|purpose|amount|savings|employment|instPercent|sexMarried|guarantors|residenceDuration|assets| age|concCredit|apartment|credits|occupation|dependents|hasPhone|foreign|            features|label|       rawPrediction|         probability|prediction|
+-------------+-------+--------+-------+-------+------+-------+----------+-----------+----------+----------+-----------------+------+----+----------+---------+-------+----------+----------+--------+-------+--------------------+-----+--------------------+--------------------+----------+
|          0.0|    0.0|    12.0|    0.0|    5.0|1108.0|    0.0|       3.0|        4.0|       2.0|       0.0|              2.0|   0.0|28.0|       2.0|      1.0|    1.0|       2.0|       0.0|     0.0|    0.0|(20,[1,3,4,6,7,8,...|  1.0|[14.1964586927573...|[0.70982293463786...|       0.0|
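In the row above, probability 0.70982... equals rawPrediction 14.19645... divided by the 20 trees. The non-integer rawPrediction suggests that Spark 1.6 does not count one hard vote per tree. Instead it appears to add up each tree's leaf class proportions, a "soft" vote.

The plain-Scala sketch below shows that aggregation with made-up leaf class counts. It is an illustration, not Spark's code.

```scala
object ForestVoteSketch {
  // Each tree contributes the class proportions of the leaf the row falls into;
  // rawPrediction is the element-wise sum of those proportions over all trees.
  def rawPrediction(leafClassCounts: Seq[Array[Double]], numClasses: Int): Array[Double] = {
    val votes = Array.fill(numClasses)(0.0)
    for (counts <- leafClassCounts) {
      val total = counts.sum
      if (total != 0.0) for (k <- 0 until numClasses) votes(k) += counts(k) / total
    }
    votes
  }

  // probability normalizes rawPrediction so it sums to 1.
  def probability(raw: Array[Double]): Array[Double] = {
    val total = raw.sum
    raw.map(_ / total)
  }

  // prediction is the class with the largest summed vote.
  def prediction(raw: Array[Double]): Double = raw.indices.maxBy(k => raw(k)).toDouble
}
```

For example, three trees whose leaves hold class counts [8, 2], [3, 7] and [9, 1] give a rawPrediction of about [2.0, 1.0], a probability of about [0.667, 0.333], and a prediction of 0.0.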
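Next, we evaluate the predictions with a BinaryClassificationEvaluator. It compares the predictions with the label column and returns the area under the ROC curve (AUC). Despite the variable name in the code, this is not classification accuracy. In this example the AUC is about 0.78.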
// create an Evaluator for binary classification, which expects two input columns: rawPrediction and label.
val evaluator = new BinaryClassificationEvaluator().setLabelCol("label")
// Evaluates predictions and returns a scalar metric areaUnderROC (larger is better).
// Note: despite the variable name, this value is the AUC, not the classification accuracy.
val accuracy = evaluator.evaluate(predictions)
accuracy: Double = 0.7824906081835722
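AUC can be read as the probability that a randomly chosen positive row (label 1.0) receives a higher score than a randomly chosen negative row (label 0.0), with tied scores counting one half. If the StringIndexer mapping above holds, the positive class here is bad credit (creditability 0.0).

The pairwise sketch below is an equivalent way to compute the metric on small data. It is not how Spark implements it (Spark builds the ROC curve from sorted scores), and it is quadratic in the number of rows.

```scala
object AucSketch {
  // AUC as a rank statistic: the fraction of (positive, negative) pairs in which
  // the positive row (label 1.0) has the higher score; tied scores count as 1/2.
  def auc(scoreAndLabel: Seq[(Double, Double)]): Double = {
    val positives = scoreAndLabel.collect { case (score, 1.0) => score }
    val negatives = scoreAndLabel.collect { case (score, 0.0) => score }
    require(positives.nonEmpty && negatives.nonEmpty, "AUC needs both classes")
    val pairScores = for (p <- positives; n <- negatives)
      yield if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    pairScores.sum / (positives.size.toDouble * negatives.size)
  }
}
```

A perfect ranking gives 1.0 and random scores give about 0.5. The tutorial's 0.78 therefore means the forest ranks a random bad-credit applicant above a random good-credit applicant roughly 78% of the time.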
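Next, we train the model again, this time with an ML pipeline, which may give better results. Combined with a CrossValidator, a pipeline provides a simple way to compare many parameter combinations using a process called grid search: you set up the parameter values to test, and MLlib tries every combination. Because the pipeline represents the whole workflow, the entire model can be tuned in one pass rather than tuning each parameter separately.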
Below, we use a ParamGridBuilder to construct the parameter grid.
// We use a ParamGridBuilder to construct a grid of parameters to search over
val paramGrid = new ParamGridBuilder()
  .addGrid(classifier.maxBins, Array(25, 28, 31))
  .addGrid(classifier.maxDepth, Array(4, 6, 8))
  .addGrid(classifier.impurity, Array("entropy", "gini"))
  .build()
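build() returns one ParamMap per combination of the listed values, so this grid holds 3 × 3 × 2 = 18 settings.

The plain-Scala sketch below reproduces that Cartesian product. It uses an ordinary Map from parameter name to value as a hypothetical stand-in for Spark's ParamMap.

```scala
object ParamGridSketch {
  // Hypothetical stand-in for Spark's ParamMap: parameter name -> value.
  type ParamSetting = Map[String, Any]

  // Cartesian product of all value lists, one setting per combination.
  def build(axes: Seq[(String, Seq[Any])]): Seq[ParamSetting] =
    axes.foldLeft(Seq(Map.empty[String, Any])) { case (settings, (name, values)) =>
      for (setting <- settings; value <- values) yield setting + (name -> value)
    }

  def main(args: Array[String]): Unit = {
    val grid = build(Seq(
      "maxBins" -> Seq(25, 28, 31),
      "maxDepth" -> Seq(4, 6, 8),
      "impurity" -> Seq("entropy", "gini")))
    println(grid.size) // 18 = 3 * 3 * 2
  }
}
```

Each additional addGrid call multiplies the number of settings, and therefore the number of models cross-validation has to train.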
Next, we create and set up a pipeline. A pipeline consists of a sequence of stages, each of which is either an Estimator or a Transformer.
val steps: Array[PipelineStage] = Array(classifier)
val pipeline = new Pipeline().setStages(steps)
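When a pipeline is fit, it runs its stages in order. An Estimator is fit on the data produced by the earlier stages and is replaced by the Transformer (model) it returns. A Transformer is used as-is. The fitted pipeline is the resulting chain of transformers. Here the pipeline has a single stage, the classifier.

The plain-Scala sketch below illustrates that flow with hypothetical, simplified Data, Transformer and Estimator types that only mirror Spark's names. For simplicity it transforms the data after every stage.

```scala
object PipelineSketch {
  // Hypothetical, simplified stand-ins for Spark's DataFrame, Transformer and Estimator.
  type Data = Seq[Map[String, Double]]

  sealed trait Stage
  trait Transformer extends Stage { def transform(data: Data): Data }
  trait Estimator extends Stage { def fit(data: Data): Transformer }

  // Run the stages in order: an Estimator is fit on the data produced by the earlier
  // stages and replaced by the Transformer it returns; a Transformer is used as-is.
  // Each stage's output feeds the next stage. The fitted transformers form the "model".
  def fit(stages: Seq[Stage], data: Data): Seq[Transformer] =
    stages.foldLeft((data, Vector.empty[Transformer])) { case ((current, fitted), stage) =>
      val t = stage match {
        case estimator: Estimator => estimator.fit(current)
        case transformer: Transformer => transformer
      }
      (t.transform(current), fitted :+ t)
    }._2

  // Applying the fitted model means running every transformer in order.
  def transform(model: Seq[Transformer], data: Data): Data =
    model.foldLeft(data)((current, t) => t.transform(current))
}
```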
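We use a CrossValidator for model selection. A CrossValidator takes an Estimator, a set of ParamMaps, and an Evaluator.

Note that cross-validation is expensive. It trains one model per parameter combination per fold. With the 18 combinations in the grid above and 10 folds, that is 180 random forests, plus one final fit with the best parameters.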
// The evaluator (areaUnderROC) is used to compare parameter settings on the held-out folds
val evaluator = new BinaryClassificationEvaluator()
  .setLabelCol("label")
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(10)
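The CrossValidator works through the parameter grid and optimizes the model automatically:

1. For each ParamMap, it splits the training data into folds.
2. It trains the Estimator on all folds but one.
3. It evaluates the result on the held-out fold with the Evaluator, and averages the metric across folds.
4. Finally, it fits the Estimator on the entire training dataset using the best ParamMap.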
// When fit is called, the stages are executed in order.
// Fit will run cross-validation and choose the best set of parameters.
// cv.fit returns a CrossValidatorModel; its bestModel is the PipelineModel
// (fitted models and transformers) trained with the best parameters.
val pipelineFittedModel = cv.fit(trainingData)
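The k-fold procedure described above can be sketched in plain Scala. The train and metric functions are hypothetical parameters standing in for the Estimator and Evaluator. Rows are assigned to folds round-robin for simplicity, whereas Spark's CrossValidator assigns them randomly. The sketch returns the best setting and its average held-out score; CrossValidator would then refit that setting on all rows.

```scala
object CrossValidationSketch {
  // k-fold model selection over candidate parameter settings.
  // `train` (hypothetical) fits a model for one setting; `metric` scores a model
  // on held-out rows, larger is better (like areaUnderROC).
  def selectBest[R, P, M](rows: IndexedSeq[R], candidates: Seq[P], k: Int)(
      train: (Seq[R], P) => M)(metric: (M, Seq[R]) => Double): (P, Double) = {
    require(k >= 2 && rows.size >= k, "need at least 2 folds and one row per fold")
    // Round-robin fold assignment: row i goes to fold i % k.
    val folds: Seq[Set[Int]] = rows.indices.groupBy(_ % k).values.map(_.toSet).toSeq
    val averaged = candidates.map { p =>
      val foldScores = folds.map { heldOut =>
        val (testIdx, trainIdx) = rows.indices.partition(heldOut.contains)
        metric(train(trainIdx.map(rows), p), testIdx.map(rows))
      }
      (p, foldScores.sum / foldScores.size)
    }
    // The best setting has the highest average held-out metric.
    averaged.maxBy(_._2)
  }
}
```

As a toy usage: with rows 1.0 to 10.0, candidate constants 0.0, 5.5 and 10.0, a "model" that simply predicts its constant, and negative mean squared error as the metric, the sketch selects 5.5.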
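Now we can use the best model found by the pipeline to make predictions on the test data and compare them with the labels. The AUC is about 0.82, an improvement over the 0.78 obtained earlier.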
import org.apache.spark.mllib.evaluation.RegressionMetrics

// call transform to make predictions on test data. The fitted model will use the best model found
val predictions = pipelineFittedModel.transform(testData)
// areaUnderROC of the best model (despite the variable name, not classification accuracy)
val accuracy = evaluator.evaluate(predictions)
accuracy: Double = 0.8204386232104784
// regression metrics computed on the 0/1 (prediction, label) pairs
val rm2 = new RegressionMetrics(
  predictions.select("prediction", "label").rdd.map(x =>
    (x(0).asInstanceOf[Double], x(1).asInstanceOf[Double])))
println("MSE: " + rm2.meanSquaredError)
println("MAE: " + rm2.meanAbsoluteError)
println("RMSE: " + rm2.rootMeanSquaredError)
println("R Squared: " + rm2.r2)
println("Explained Variance: " + rm2.explainedVariance + "\n")
MSE: 0.2575250836120402
MAE: 0.25752508361204013
RMSE: 0.5074692932700856
R Squared: -0.1687988628287138
Explained Variance: 0.15466269952237702
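MSE and MAE above are equal (about 0.2575). That is expected: prediction and label are both 0 or 1, so every error is 0 or 1, and its square equals its absolute value. Both metrics are therefore the misclassification rate. This implies a test accuracy of roughly 1 − 0.2575 ≈ 0.74. R² and explained variance are not very meaningful for a classifier.

A small plain-Scala check of that arithmetic, with toy pairs:

```scala
object ZeroOneMetrics {
  // For 0/1 predictions and labels every error is 0 or 1, so the squared error
  // and the absolute error are identical and both average to the error rate.
  def mse(pairs: Seq[(Double, Double)]): Double =
    pairs.map { case (p, l) => (p - l) * (p - l) }.sum / pairs.size

  def mae(pairs: Seq[(Double, Double)]): Double =
    pairs.map { case (p, l) => math.abs(p - l) }.sum / pairs.size

  // Fraction of rows where the prediction matches the label.
  def accuracy(pairs: Seq[(Double, Double)]): Double =
    pairs.count { case (p, l) => p == l }.toDouble / pairs.size
}
```

With the toy (prediction, label) pairs (0,0), (1,0), (1,1), (0,1), (0,0), MSE and MAE are both 0.4 and accuracy is 0.6. In the spark-shell session, accuracy could presumably be computed directly with predictions.filter($"prediction" === $"label").count().toDouble / predictions.count().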