我觉得首先有必要简单说说交叉验证,即用只有一个训练集的时候,用一部分数据训练,一部分做测试,当然怎么分配及时不同的方法了。
1)k-folder cross-validation:
k个子集,每个子集均做一次测试集,其余的作为训练集。交叉验证重复k次,每次选择一个子集作为测试集,并将k次的平均交叉验证识别正确率作为结果。优点:所有的样本都被作为了训练集和测试集,每个样本都被验证一次。10-folder通常被使用。
2)K * 2 folder cross-validation
是k-folder cross-validation的一个变体,对每一个folder,都平均分成两个集合s0,s1,我们先在集合s0训练用s1测试,然后用s1训练s0测试。优点是:测试和训练集都足够大,每一个个样本都被作为训练集和测试集。一般使用k=10
3)least-one-out cross-validation(loocv)
假设dataset中有n个样本,那LOOCV也就是n-CV,意思是每个样本单独作为一次测试集,剩余n-1个样本则做为训练集。优点:
1)每一回合中几乎所有的样本皆用于训练model,因此最接近母体样本的分布,估测所得的generalization error比较可靠。
2)实验过程中没有随机因素会影响实验数据,确保实验过程是可以被复制的。
但LOOCV的缺点则是计算成本高,为需要建立的models数量与总样本数量相同,当总样本数量相当多时,LOOCV在实作上便有困难,除非每次训练model的速度很快,或是可以用平行化计算减少计算所需的时间。
关键代码:
//直接调用Evaluation即可完成Evaluation eval = null;for (int i = 0; i < 10; i++) { eval = new Evaluation(Train); eval.crossValidateModel(m_classifier, Train, 10, new Random(i),args);// 实现交叉验证模型}System.out.println(eval.toSummaryString());// 输出总结信息System.out.println(eval.toClassDetailsString());// 输出分类详细信息System.out.println(eval.toMatrixString());// 输出分类的混淆矩阵
这个在网上找了很久,没找到,却偶然一次发现了,其实很简单,只要因为好一点的话,看国外论坛就好多了。 保存模型方法:
SerializationHelper.write("LibSVM.model", classifier4);//参数一为模型保存文件,classifier4为要保存的模型加载模型:
Classifier classifier8 = (Classifier) weka.core.SerializationHelper.read("LibSVM.model");
全部代码:
package weka_test;import java.io.File;import java.io.IOException;import weka.classifiers.Classifier;import weka.classifiers.trees.J48;import weka.core.Instance;import weka.core.Instances;import weka.core.converters.ArffLoader;import weka.experiment.InstanceQuery;import weka.classifiers.Evaluation;import java.util.Random;public class test { /** * oracleInput * @return data * @throws Exception*/ public static Instances oracleInput() throws Exception{ InstanceQuery query = new InstanceQuery(); String sql = "SELECT to_char(z.cydate,'yyyy/mm') AS d,sum(z.bcmoney) as c FROM zybc z"+ " WHERE to_char(z.cydate,'yyyy/mm') IS NOT NULL"+ " GROUP BY to_char(z.cydate,'yyyy/mm') ORDER BY to_date(to_char(z.cydate,'yyyy/mm'),'yyyy/mm') ASC"; //System.out.println(sql); query.setCustomPropsFile(new File("weka/weka_oracle.props")); query.setDatabaseURL("jdbc:oracle:thin:@192.168.2.133:1521/XE"); query.setUsername("***"); query.setPassword("***"); query.setQuery(sql); Instances data = query.retrieveInstances(); return data; } /** * mysqlInput * @return data * @throws Exception*/ public static Instances mysqlInput() throws Exception{ InstanceQuery query = new InstanceQuery(); String sql = "SELECT * FROM iris"; //System.out.println(sql); query.setCustomPropsFile(new File("weka/weka_mysql.props")); query.setDatabaseURL("jdbc:mysql://localhost:3306/test"); query.setUsername("***"); query.setPassword("***"); query.setQuery(sql); Instances data = query.retrieveInstances(); return data; } /** * @param args * @throws Exception */ public static void main(String[] args) throws Exception { // TODO Auto-generated method stub Classifier m_classifier = new J48(); /*File inputFile = new File("D://Program Files//Weka-3-7//data//iris.arff");//训练语料文件 ArffLoader atf = new ArffLoader(); atf.setFile(inputFile); Instances instancesTrain = atf.getDataSet(); // 读入训练文件 */ Instances Train = mysqlInput(); Instances Test = mysqlInput(); Test.setClassIndex(4); //设置分类属性所在行号(第一行为0号),instancesTest.numAttributes()可以取得属性总数 double sum = Test.numInstances(),//测试语料实例数 right = 0.0f; Train.setClassIndex(4); m_classifier.buildClassifier(Train); //训练 //System.out.println(m_classifier.toString()); //2、利用模型进行预测 int a=0,b=0,c=0,d=0;//记录每个类别的个数,方便计算评价指标for (int i = 0; i < Test.numInstances(); i++) { double classification = m_classifier.classifyInstance(Train.instance(i)); double classValue = Train.instance(i).classValue();if (classification == 0.0 && classValue == 0.0) { a++; } else if (classification == 0.0 && classValue == 1.0) { b++; } else if (classification == 1.0 && classValue == 0.0) { c++; } else if (classification == 1.0 && classValue == 1.0) { d++; } } // 3、得出预测效果评测指标 double precision = (double) a / (a + b); double recall = (double) a / (a + c); double fMeasure = 2 * precision * recall / (precision + recall); System.out.println("precision\trecall\tF-Measure"); System.out.println((precision) + "\t\t"+ (recall) + "\t"+ (fMeasure));for(int i = 0;i
CSDN博客原文
授人以鱼不如授人以渔:
python sklearn数据预处理:
广义线性模型--Generalized Linear Models