本文主要是介绍使用svm训练mist数据集,希望对大家解决编程问题提供一定的参考价值,需要的开发者们随着小编来一起学习吧!
1.SVM对象的创建和训练
1.1 创建svm
Ptr<ml::SVM> svm = ml::SVM::create();
1.2 svm参数设置
//设置SVM参数
svm->setType(ml::SVM::C_SVC);
svm->setKernel(ml::SVM::RBF);
svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));
或者是
cv::SVM::Params params;
params.svmType = cv::SVM::C_SVC;
params.kernelType = cv::SVM::RBF;
params.termCrit = cv::TermCriteria(cv::TermCriteria::MAX_ITER, 100, 1e-6);
params.C = 1.0;
params.gamma = 0.1;
2. 使用mist数据集进行分类
使用mist数据集进行分类
#include <opencv2/opencv.hpp>
#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <Winsock2.h>
//在对话框左侧选择“配置属性->链接器->输入”,在右侧的“附加依赖项”中添加ws2_32.lib库文件using namespace cv;
using namespace std;//定义存储训练图像和标签的向量
vector<Mat> train_images;
vector<int> train_labels;//定义函数来读取MNIST数据集
//是将一个无符号长整形数从网络字节顺序转换为主机字节顺序
//ntohl()返回一个以主机字节顺序表达的数。
void read_MNIST(string filename, vector<Mat>& vec_images, vector<int>& vec_labels)
{ifstream file(filename, ios::binary);if (file.is_open()){cout << "begin to read MNIST" << endl;int magic_number = 0;int number_of_images = 0;int rows = 0;int cols = 0;file.read((char*)&magic_number, sizeof(magic_number));magic_number = ntohl(magic_number);file.read((char*)&number_of_images, sizeof(number_of_images));number_of_images = ntohl(number_of_images);file.read((char*)&rows, sizeof(rows));rows = ntohl(rows);file.read((char*)&cols, sizeof(cols));cols = ntohl(cols);for (int i = 0; i < number_of_images; ++i){Mat img = Mat::zeros(rows, cols, CV_8UC1);for (int r = 0; r < rows; ++r){for (int c = 0; c < cols; ++c){unsigned char temp = 0;file.read((char*)&temp, sizeof(temp));img.at<uchar>(r, c) = (int)temp;}}int label = 0;file.read((char*)&label, sizeof(label));label = ntohl(label);vec_images.push_back(img);vec_labels.push_back(label);}cout << "read MNIST finish" << endl;}
}int main()
{//读取训练数据string train_images_path = "E:/det/mnist/train-images.idx3-ubyte";string train_labels_path = "E:/det/mnist/train-labels.idx1-ubyte";read_MNIST(train_images_path, train_images, train_labels);//设置SVM参数Ptr<ml::SVM> svm = ml::SVM::create();svm->setType(ml::SVM::C_SVC);svm->setKernel(ml::SVM::RBF);svm->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 1e-6));//将图像转换为特征向量Mat trainData;cout << "images 2 vector" << endl;for (int i = 0; i < train_images.size(); ++i){Mat img;train_images[i].convertTo(img, CV_32FC1);img = img.reshape(1, 1);trainData.push_back(img);}//训练SVM模型cout << "training svm"<< endl;Mat labelsMat(train_labels.size(), 1, CV_32SC1, train_labels.data());svm->train(trainData, ml::ROW_SAMPLE, labelsMat);cout << "training finished" << endl;//保存模型svm->save("svm_model.xml");return 0;
}
这篇关于使用svm训练mist数据集的文章就介绍到这儿,希望我们推荐的文章对编程师们有所帮助!