博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
交叉验证的Java weka实现,并保存和重载模型
阅读量:4357 次
发布时间:2019-06-07

本文共 4761 字,大约阅读时间需要 15 分钟。

我觉得首先有必要简单说说交叉验证,即用只有一个训练集的时候,用一部分数据训练,一部分做测试,当然怎么分配及时不同的方法了。

1)k-folder cross-validation:

k个子集,每个子集均做一次测试集,其余的作为训练集。交叉验证重复k次,每次选择一个子集作为测试集,并将k次的平均交叉验证识别正确率作为结果。
优点:所有的样本都被作为了训练集和测试集,每个样本都被验证一次。10-folder通常被使用。

2)K * 2 folder cross-validation

是k-folder cross-validation的一个变体,对每一个folder,都平均分成两个集合s0,s1,我们先在集合s0训练用s1测试,然后用s1训练s0测试。
优点是:测试和训练集都足够大,每一个个样本都被作为训练集和测试集。一般使用k=10

3)least-one-out cross-validation(loocv)

假设dataset中有n个样本,那LOOCV也就是n-CV,意思是每个样本单独作为一次测试集,剩余n-1个样本则做为训练集。
优点:
1)每一回合中几乎所有的样本皆用于训练model,因此最接近母体样本的分布,估测所得的generalization error比较可靠。
2)实验过程中没有随机因素会影响实验数据,确保实验过程是可以被复制的。
但LOOCV的缺点则是计算成本高,为需要建立的models数量与总样本数量相同,当总样本数量相当多时,LOOCV在实作上便有困难,除非每次训练model的速度很快,或是可以用平行化计算减少计算所需的时间。
废话不多说,直接上代码:
关键代码:
 
//直接调用Evaluation即可完成Evaluation eval = null;for (int i = 0; i < 10; i++) {   eval = new Evaluation(Train);   eval.crossValidateModel(m_classifier, Train, 10, new Random(i),args);// 实现交叉验证模型}System.out.println(eval.toSummaryString());// 输出总结信息System.out.println(eval.toClassDetailsString());// 输出分类详细信息System.out.println(eval.toMatrixString());// 输出分类的混淆矩阵
 

这个在网上找了很久,没找到,却偶然一次发现了,其实很简单,只要因为好一点的话,看国外论坛就好多了。
保存模型方法:
 
SerializationHelper.write("LibSVM.model", classifier4);//参数一为模型保存文件,classifier4为要保存的模型
加载模型:
 
Classifier classifier8 = (Classifier) weka.core.SerializationHelper.read("LibSVM.model");
全部代码:
package weka_test;import java.io.File;import java.io.IOException;import weka.classifiers.Classifier;import weka.classifiers.trees.J48;import weka.core.Instance;import weka.core.Instances;import weka.core.converters.ArffLoader;import weka.experiment.InstanceQuery;import weka.classifiers.Evaluation;import java.util.Random;public class test {   /**    * oracleInput    * @return data    * @throws Exception*/   public static Instances oracleInput() throws Exception{      InstanceQuery query = new InstanceQuery();      String sql = "SELECT to_char(z.cydate,'yyyy/mm') AS d,sum(z.bcmoney) as c FROM zybc z"+ " WHERE to_char(z.cydate,'yyyy/mm') IS NOT NULL"+ " GROUP BY to_char(z.cydate,'yyyy/mm') ORDER BY to_date(to_char(z.cydate,'yyyy/mm'),'yyyy/mm') ASC";      //System.out.println(sql);      query.setCustomPropsFile(new File("weka/weka_oracle.props"));      query.setDatabaseURL("jdbc:oracle:thin:@192.168.2.133:1521/XE");      query.setUsername("***");      query.setPassword("***");      query.setQuery(sql);      Instances data = query.retrieveInstances();          return data;   }   /**    * mysqlInput    * @return data    * @throws Exception*/   public static Instances mysqlInput() throws Exception{      InstanceQuery query = new InstanceQuery();      String sql = "SELECT * FROM iris";      //System.out.println(sql);      query.setCustomPropsFile(new File("weka/weka_mysql.props"));      query.setDatabaseURL("jdbc:mysql://localhost:3306/test");      query.setUsername("***");      query.setPassword("***");      query.setQuery(sql);      Instances data = query.retrieveInstances();    return data;   }   /**    * @param args    * @throws Exception */   public static void main(String[] args) throws Exception {      // TODO Auto-generated method stub      Classifier m_classifier = new J48();        /*File inputFile = new File("D://Program Files//Weka-3-7//data//iris.arff");//训练语料文件        ArffLoader atf = new ArffLoader();         atf.setFile(inputFile);        Instances instancesTrain = atf.getDataSet(); // 读入训练文件    */       Instances Train = mysqlInput();      Instances Test = mysqlInput();        Test.setClassIndex(4); //设置分类属性所在行号(第一行为0号),instancesTest.numAttributes()可以取得属性总数        double sum = Test.numInstances(),//测试语料实例数        right = 0.0f;        Train.setClassIndex(4);        m_classifier.buildClassifier(Train); //训练                  //System.out.println(m_classifier.toString());      //2、利用模型进行预测 int a=0,b=0,c=0,d=0;//记录每个类别的个数,方便计算评价指标for (int i = 0; i < Test.numInstances(); i++) {         double classification = m_classifier.classifyInstance(Train.instance(i));         double classValue = Train.instance(i).classValue();if (classification == 0.0 && classValue == 0.0) {            a++;         } else if (classification == 0.0 && classValue == 1.0) {            b++;         } else if (classification == 1.0 && classValue == 0.0) {            c++;         } else if (classification == 1.0 && classValue == 1.0) {            d++;         }      }      // 3、得出预测效果评测指标      double precision = (double) a / (a + b);      double recall = (double) a / (a + c);      double fMeasure = 2 * precision * recall / (precision + recall);      System.out.println("precision\trecall\tF-Measure");      System.out.println((precision) + "\t\t"+ (recall) + "\t"+ (fMeasure));for(int  i = 0;i
CSDN博客原文
授人以鱼不如授人以渔:
python sklearn数据预处理:
广义线性模型--Generalized Linear Models

转载于:https://www.cnblogs.com/aaronchou820/p/6696254.html

你可能感兴趣的文章
Python学习---django之admin简介
查看>>
个人工作总结11(第二阶段)
查看>>
配置完IDEA开发lua后用idea竟然打不开lua的文件。
查看>>
synchronized、锁、多线程同步的原理是咋样
查看>>
AutoHotKey 快速入门
查看>>
sharepoint 2010批量导入数据
查看>>
Linux学习-Linux历史(总结篇)
查看>>
c++笔记
查看>>
NoSql笔记
查看>>
chromium os系统编译与环境搭建
查看>>
给元素绑定 class
查看>>
如何对iPhone进行屏幕录像
查看>>
网站技术架构
查看>>
maven 配置阿里云仓库
查看>>
合理构建产品形态(一)——谁是目标用户
查看>>
Tomcat服务器与HTTP协议
查看>>
Android studio开发APP的的目录结构
查看>>
VS 2010 Beta2中WPF有哪些改进?
查看>>
一个野生程序员开博日
查看>>
20180528小测
查看>>