Solving Iris classification using XGBoost and C#
Permalink: https://www.kyfws.com/ai/solving-iris-classification-using-xgboost-and-csha-zh/
Copyright notice: unless otherwise stated, all articles on this blog are licensed under BY-NC-SA. Please credit the source when reposting.
Original article: https://www.codeproject.com/Articles/5182988/Solving-Iris-classification-using-XGBoost-and-Csha
Original author: Sau002
Translated for this site by robot-v1.0
Introduction
Image source: Wikipedia
In this article I demonstrate how to use the C# wrapper of the popular XGBoost unmanaged library. XGBoost stands for "Extreme Gradient Boosting". I use the well-known Iris dataset to train and test a model. My objective is to share what I learned about embedding a machine learning algorithm such as extreme gradient boosting in a C# application. Before moving on, I must extend my gratitude to the developers of the XGBoost unmanaged library and to the developers of the .NET wrapper library.
Background
This article assumes the reader has an intermediate knowledge of the following:
- Decision tree algorithm
- Gradient boosting algorithm
- Data normalization
- C#
This article and the accompanying code do not provide an in-depth tutorial on decision trees or gradient boosting algorithms. Instead, I have provided links to YouTube training videos which, in my opinion, are of immense educational value.
Overview of Gradient Boost Classification algorithm
Intro to decision trees (StatQuest)
Understanding Gini index while constructing a decision tree
Intro to AdaBoost
Intro to Gradient Boost
XGBoost library (C#)
Managed wrapper
The C/C++ source code for the original XGBoost library is available on GitHub, where you can also find build instructions for Windows. Thanks to the efforts of PicNet, we can skip the step of compiling the unmanaged sources and jump directly to the managed wrapper.
Simple linear classification problem
We will carry out a simple exercise in which we train a model to classify two clusters of points that are cleanly linearly separable.
/// <summary>
/// Two classes of vectors - Class-Blue and Class-Red
/// Class-Blue - The vectors are centered around the point (+0.5,+0.5) and label value=1
/// Class-Red - The vectors are centered around the point (-0.5,-0.5) and label value=0
/// </summary>
[TestMethod]
public void LinearClassification1()
{
    var xgb = new XGBoost.XGBClassifier();
    float[][] vectorsTrain = new float[][]
    {
        new[] {0.5f,0.5f},
        new[] {0.6f,0.6f},
        new[] {0.6f,0.4f},
        new[] {0.4f,0.6f},
        new[] {0.4f,0.4f},
        new[] {-0.5f,-0.5f},
        new[] {-0.6f,-0.6f},
        new[] {-0.6f,-0.4f},
        new[] {-0.4f,-0.6f},
        new[] {-0.4f,-0.4f},
    };
    // One label per training vector, in the same order: 1 = Class-Blue, 0 = Class-Red
    var labelsTrain = new[]
    {
        1.0f,
        1.0f,
        1.0f,
        1.0f,
        1.0f,
        0.0f,
        0.0f,
        0.0f,
        0.0f,
        0.0f,
    };
    //
    // Ensure count of training labels=count of training vectors
    //
    Assert.AreEqual(vectorsTrain.Length, labelsTrain.Length);
    //
    // Train the model
    //
    xgb.Fit(vectorsTrain, labelsTrain);
    //
    // Test the model using test vectors
    //
    float[][] vectorsTest = new float[][]
    {
        new[] {0.55f,0.55f},
        new[] {0.55f,0.45f},
        new[] {0.45f,0.55f},
        new[] {0.45f,0.45f},
        new[] {-0.55f,-0.55f},
        new[] {-0.55f,-0.45f},
        new[] {-0.45f,-0.55f},
        new[] {-0.45f,-0.45f},
    };
    var labelsTestExpected = new[]
    {
        1.0f,
        1.0f,
        1.0f,
        1.0f,
        0.0f,
        0.0f,
        0.0f,
        0.0f,
    };
    float[] labelsTestPredicted = xgb.Predict(vectorsTest);
    //
    // Verify that predicted labels match the expected labels
    // (MSTest expects the arguments in the order: expected, actual)
    //
    CollectionAssert.AreEqual(labelsTestExpected, labelsTestPredicted);
}
Implementing XOR logic
XOR logic is more complex than a linear classification: the data points cannot be separated by a single straight line.
XOR truth table
X | Y | OUTPUT
--------------
1 | 0 | 1
--------------
0 | 1 | 1
--------------
0 | 0 | 0
--------------
1 | 1 | 0
--------------
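The truth table works on bits, whereas the sample code that follows trains on real-valued points in the unit square. A minimal sketch of the mapping this seems to imply, assuming each coordinate is thresholded at 0.5. The class XorSketch and the helper XorLabel are hypothetical names used for illustration and are not part of the article's code:

```csharp
using System;

public static class XorSketch
{
    // Hypothetical helper: treats a coordinate >= 0.5 as logical 1, otherwise 0,
    // and returns the XOR of the two bits as a float label (1.0f or 0.0f).
    public static float XorLabel(double x, double y)
    {
        bool bitX = x >= 0.5;
        bool bitY = y >= 0.5;
        return (bitX ^ bitY) ? 1.0f : 0.0f;
    }
}
```

Under this assumption a point such as (0.2, 0.8) falls in the "0,1" quadrant and is labelled 1, while (0.8, 0.8) falls in the "1,1" quadrant and is labelled 0. No single straight line separates the two label values, which is why a tree-based model is a reasonable fit here.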
Sample code
[TestMethod]
public void TestMethod1()
{
    var xgb = new XGBoost.XGBClassifier();
    //
    // Generate training vectors: countTrainingPoints / 2 = 25 points per quadrant.
    // Judging by the call sites, the arguments are (count, xMin, xMax, yMin, yMax, label).
    //
    int countTrainingPoints = 50;
    entity.XGBArray trainClass_0_1 = Util.GenerateRandom2dPoints(countTrainingPoints / 2,
                                        0.0, 0.5,
                                        0.5, 1.0, 1.0);//0,1
    entity.XGBArray trainClass_1_0 = Util.GenerateRandom2dPoints(countTrainingPoints / 2,
                                        0.5, 1.0,
                                        0.0, 0.5, 1.0);//1,0
    entity.XGBArray trainClass_0_0 = Util.GenerateRandom2dPoints(countTrainingPoints / 2,
                                        0.0, 0.5,
                                        0.0, 0.5, 0.0);//0,0
    entity.XGBArray trainClass_1_1 = Util.GenerateRandom2dPoints(countTrainingPoints / 2,
                                        0.5, 1.0,
                                        0.5, 1.0, 0.0);//1,1
    //
    // Train the model
    //
    entity.XGBArray allVectorsTraining = Util.UnionOfXGBArrays(trainClass_0_1, trainClass_1_0, trainClass_0_0, trainClass_1_1);
    xgb.Fit(allVectorsTraining.Vectors, allVectorsTraining.Labels);
    //
    // Test the model with points drawn well inside each quadrant
    //
    int countTestingPoints = 10;
    entity.XGBArray testClass_0_1 = Util.GenerateRandom2dPoints(countTestingPoints,
                                        0.1, 0.4,
                                        0.6, 0.9, 1.0);//0,1
    entity.XGBArray testClass_1_0 = Util.GenerateRandom2dPoints(countTestingPoints,
                                        0.6, 0.9,
                                        0.1, 0.4, 1.0);//1,0
    entity.XGBArray testClass_0_0 = Util.GenerateRandom2dPoints(countTestingPoints,
                                        0.1, 0.4,
                                        0.1, 0.4, 0.0);//0,0
    entity.XGBArray testClass_1_1 = Util.GenerateRandom2dPoints(countTestingPoints,
                                        0.6, 0.9,
                                        0.6, 0.9, 0.0);//1,1
    entity.XGBArray allVectorsTest = Util.UnionOfXGBArrays(testClass_0_1, testClass_1_0, testClass_0_0, testClass_1_1);
    var resultsActual = xgb.Predict(allVectorsTest.Vectors);
    // MSTest expects the arguments in the order: expected, actual
    CollectionAssert.AreEqual(allVectorsTest.Labels, resultsActual);
}
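The helpers Util.GenerateRandom2dPoints and Util.UnionOfXGBArrays are used in the test above but their bodies are not shown in this article. The following is only a guess at what they might look like, inferred from the call sites: the class names XGBArraySketch and UtilSketch are hypothetical stand-ins, the parameter order (count, xMin, xMax, yMin, yMax, label) is an assumption, and the fixed random seed is my own choice to make test runs repeatable.

```csharp
using System;
using System.Linq;

// Hypothetical stand-in for the article's entity.XGBArray;
// only the Vectors and Labels members appear in the article.
public class XGBArraySketch
{
    public float[][] Vectors { get; set; }
    public float[] Labels { get; set; }
}

public static class UtilSketch
{
    // Fixed seed (an assumption) so that repeated runs produce the same points.
    private static readonly Random Rng = new Random(42);

    // Draws 'count' points uniformly between (xMin, yMin) and (xMax, yMax),
    // and tags every point with the same 'label'.
    public static XGBArraySketch GenerateRandom2dPoints(
        int count, double xMin, double xMax, double yMin, double yMax, double label)
    {
        var vectors = new float[count][];
        var labels = new float[count];
        for (int i = 0; i < count; i++)
        {
            float x = (float)(xMin + Rng.NextDouble() * (xMax - xMin));
            float y = (float)(yMin + Rng.NextDouble() * (yMax - yMin));
            vectors[i] = new[] { x, y };
            labels[i] = (float)label;
        }
        return new XGBArraySketch { Vectors = vectors, Labels = labels };
    }

    // Concatenates several arrays in argument order, so that Vectors[i]
    // and Labels[i] still describe the same point after the union.
    public static XGBArraySketch UnionOfXGBArrays(params XGBArraySketch[] arrays)
    {
        return new XGBArraySketch
        {
            Vectors = arrays.SelectMany(a => a.Vectors).ToArray(),
            Labels = arrays.SelectMany(a => a.Labels).ToArray()
        };
    }
}
```

Keeping vectors and labels in one object and concatenating both in the same order is what lets the test pass allVectorsTraining.Vectors and allVectorsTraining.Labels to Fit without any further bookkeeping.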
Persisting a model to file
Once a model has been trained and found to produce satisfactory results, you will want to use it in production. The method SaveModelToFile persists the model to a binary file. The static method LoadClassifierFromFile rehydrates the saved model.
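// Use one variable for the file name so that saving and loading refer to the same file
string fileModel = "SimpleLinearClassifier.dat";
var xgbTrainer = new XGBoost.XGBClassifier();
//
// Train the model (training code omitted here)
//
xgbTrainer.SaveModelToFile(fileModel);
//
// Load the persisted model
//
var xgbProduction = XGBoost.XGBClassifier.LoadClassifierFromFile(fileModel);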
Iris dataset
Overview
Source: Wikipedia
The data set contains 50 records from each of the three species of the Iris flower. It is a standard test case for demonstrating many statistical classification techniques. Each record holds four numeric measurements (sepal length, sepal width, petal length and petal width, all in centimetres) followed by the species name, which is one of:
- Iris-setosa
- Iris-versicolor
- Iris-virginica
Data structure
Source: Wikipedia
Parsing IRIS records from CSV
using System.Collections.Generic;
using System.Globalization;
using Microsoft.VisualBasic.FileIO;

//
// The C# class Iris will be used for capturing a single data row.
// Col1..Col4 hold the four numeric measurements; Petal holds the species name.
//
public class Iris
{
    public float Col1 { get; set; }
    public float Col2 { get; set; }
    public float Col3 { get; set; }
    public float Col4 { get; set; }
    public string Petal { get; set; }
}
//
// The function LoadIris will read the specified file line by line and create an instance of the Iris POCO
// The class TextFieldParser from the assembly Microsoft.VisualBasic is being used here
// (declared internal static because the tests call it as IrisUtils.LoadIris)
//
internal static Iris[] LoadIris(string filename)
{
    string pathFull = System.IO.Path.Combine(Util.GetProjectDir2(), filename);
    List<Iris> records = new List<Iris>();
    using (var parser = new TextFieldParser(pathFull))
    {
        parser.TextFieldType = FieldType.Delimited;
        parser.SetDelimiters(",");
        while (!parser.EndOfData)
        {
            var fields = parser.ReadFields();
            Iris oRecord = new Iris();
            // Parse with the invariant culture so "5.1" is read correctly on any machine
            oRecord.Col1 = float.Parse(fields[0], CultureInfo.InvariantCulture);
            oRecord.Col2 = float.Parse(fields[1], CultureInfo.InvariantCulture);
            oRecord.Col3 = float.Parse(fields[2], CultureInfo.InvariantCulture);
            oRecord.Col4 = float.Parse(fields[3], CultureInfo.InvariantCulture);
            oRecord.Petal = fields[4];
            records.Add(oRecord);
        }
    }
    return records.ToArray();
}
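If taking a dependency on Microsoft.VisualBasic is undesirable, a single line of the data file can also be parsed with plain string operations. This is a sketch of an alternative, not the article's method: the names IrisRow and IrisLineParser are hypothetical, and it assumes every non-blank line has exactly five comma-separated fields with no quoting.

```csharp
using System;
using System.Globalization;

// Hypothetical equivalent of the article's Iris POCO, repeated so the sketch is self-contained.
public class IrisRow
{
    public float Col1 { get; set; }
    public float Col2 { get; set; }
    public float Col3 { get; set; }
    public float Col4 { get; set; }
    public string Petal { get; set; }
}

public static class IrisLineParser
{
    // Parses one line such as "5.1,3.5,1.4,0.2,Iris-setosa".
    // Returns null for blank lines so the caller can skip them.
    public static IrisRow ParseLine(string line)
    {
        if (string.IsNullOrWhiteSpace(line))
        {
            return null;
        }
        string[] fields = line.Split(',');
        if (fields.Length != 5)
        {
            throw new FormatException($"Expected 5 fields but found {fields.Length}: '{line}'");
        }
        return new IrisRow
        {
            // Invariant culture: the data file always uses '.' as the decimal separator
            Col1 = float.Parse(fields[0], CultureInfo.InvariantCulture),
            Col2 = float.Parse(fields[1], CultureInfo.InvariantCulture),
            Col3 = float.Parse(fields[2], CultureInfo.InvariantCulture),
            Col4 = float.Parse(fields[3], CultureInfo.InvariantCulture),
            Petal = fields[4].Trim()
        };
    }
}
```

A caller could feed it the lines of the file one by one and discard the null results. TextFieldParser remains the safer choice if the input might ever contain quoted fields.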
Creating a feature vector from CSV
/// <summary>
/// Create XGBoost consumable feature vector from Iris POCO classes
/// </summary>
internal static XGVector<Iris>[] ConvertFromIrisToFeatureVectors(Iris[] records)
{
    List<XGVector<Iris>> vectors = new List<XGVector<Iris>>();
    foreach (var rec in records)
    {
        XGVector<Iris> newVector = new XGVector<Iris>();
        newVector.Original = rec;
        newVector.Features = new float[]
        {
            rec.Col1, rec.Col2, rec.Col3, rec.Col4
        };
        newVector.Label = ConvertLabelFromStringToNumeric(rec.Petal);
        vectors.Add(newVector);
    }
    return vectors.ToArray();
}
/// <summary>
/// Converts the string based name of the petal to a numeric representation
/// </summary>
internal static float ConvertLabelFromStringToNumeric(string petal)
{
    if (petal.Contains("setosa"))
    {
        return 0.0f;
    }
    else if (petal.Contains("versicolor"))
    {
        return 1.0f;
    }
    else if (petal.Contains("virginica"))
    {
        return 2.0f;
    }
    else
    {
        throw new NotImplementedException();
    }
}
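The training test later in this article calls IrisUtils.ConvertLabelFromNumericToString, whose body is not shown. A plausible inverse of the mapping above might look like the following. This is a sketch under assumptions: the class name IrisLabelSketch is hypothetical, and the exact strings returned by the author's version are not known from this article, so the three species names from the overview are used.

```csharp
using System;

public static class IrisLabelSketch
{
    // Hypothetical inverse of ConvertLabelFromStringToNumeric:
    // 0 -> setosa, 1 -> versicolor, 2 -> virginica.
    public static string ConvertLabelFromNumericToString(float label)
    {
        // Labels are whole numbers stored as floats; round before comparing
        switch ((int)Math.Round(label))
        {
            case 0: return "Iris-setosa";
            case 1: return "Iris-versicolor";
            case 2: return "Iris-virginica";
            default:
                throw new ArgumentOutOfRangeException(
                    nameof(label), label, "Unknown Iris class label");
        }
    }
}
```

Because both the expected and the predicted class are passed through the same function before being compared, the test only needs the mapping to be consistent, not to match the file's spelling exactly.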
Loading IRIS: putting it all together
[TestMethod]
public void BasicLoadData()
{
    string filename = "Iris\\Iris.train.data";
    iris.Iris[] records = IrisUtils.LoadIris(filename);
    entity.XGVector<iris.Iris>[] vectors = IrisUtils.ConvertFromIrisToFeatureVectors(records);
    Assert.IsTrue(records.Length >= 140);
}
Training and testing IRIS
[TestMethod]
public void TrainAndTestIris()
{
    //
    // Load training vectors
    //
    string filenameTrain = "Iris\\Iris.train.data";
    iris.Iris[] recordsTrain = IrisUtils.LoadIris(filenameTrain);
    entity.XGVector<iris.Iris>[] vectorsTrain = IrisUtils.ConvertFromIrisToFeatureVectors(recordsTrain);
    //
    // Load testing vectors
    //
    string filenameTest = "Iris\\Iris.test.data";
    iris.Iris[] recordsTest = IrisUtils.LoadIris(filenameTest);
    entity.XGVector<iris.Iris>[] vectorsTest = IrisUtils.ConvertFromIrisToFeatureVectors(recordsTest);
    int noOfClasses = 3;
    var xgbc = new XGBoost.XGBClassifier(objective: "multi:softprob", numClass: noOfClasses);
    entity.XGBArray arrTrain = Util.ConvertToXGBArray(vectorsTrain);
    entity.XGBArray arrTest = Util.ConvertToXGBArray(vectorsTest);
    xgbc.Fit(arrTrain.Vectors, arrTrain.Labels);
    var outcomeTest = xgbc.Predict(arrTest.Vectors);
    for (int index = 0; index < arrTest.Vectors.Length; index++)
    {
        string sExpected = IrisUtils.ConvertLabelFromNumericToString(arrTest.Labels[index]);
        // The predictions arrive as one flat array: noOfClasses probabilities per test vector
        float[] arrResults = new float[]
        {
            outcomeTest[index * noOfClasses + 0],
            outcomeTest[index * noOfClasses + 1],
            outcomeTest[index * noOfClasses + 2]
        };
        // The predicted class is the one with the highest probability
        int indexWithMaxValue = Util.GetIndexWithMaxValue(arrResults);
        string sActualClass = IrisUtils.ConvertLabelFromNumericToString((float)indexWithMaxValue);
        Trace.WriteLine($"{index} Expected={sExpected} Actual={sActualClass}");
        // MSTest expects the arguments in the order: expected, actual
        Assert.AreEqual(sExpected, sActualClass);
    }
    string pathFull = System.IO.Path.Combine(Util.GetProjectDir2(), _fileModelIris);
    xgbc.SaveModelToFile(pathFull);
}
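With the "multi:softprob" objective, the test above treats the prediction output as a flat array holding one probability per class for each test vector, and picks the class with the highest probability. The helper Util.GetIndexWithMaxValue is not shown in this article. The sketch below is an assumption about how it could be written, together with a hypothetical PredictedClasses method that applies it to every row; the class name SoftprobSketch is also illustrative only.

```csharp
using System;

public static class SoftprobSketch
{
    // Returns the position of the largest value (the first one wins on ties).
    public static int GetIndexWithMaxValue(float[] values)
    {
        if (values == null || values.Length == 0)
        {
            throw new ArgumentException("At least one value is required", nameof(values));
        }
        int best = 0;
        for (int i = 1; i < values.Length; i++)
        {
            if (values[i] > values[best])
            {
                best = i;
            }
        }
        return best;
    }

    // Interprets 'flat' as rows of 'classCount' probabilities laid out one row
    // after another, and returns the most probable class index for each row.
    public static int[] PredictedClasses(float[] flat, int classCount)
    {
        if (classCount <= 0 || flat.Length % classCount != 0)
        {
            throw new ArgumentException("Length must be a multiple of classCount", nameof(flat));
        }
        int rows = flat.Length / classCount;
        var result = new int[rows];
        for (int r = 0; r < rows; r++)
        {
            var slice = new float[classCount];
            Array.Copy(flat, r * classCount, slice, 0, classCount);
            result[r] = GetIndexWithMaxValue(slice);
        }
        return result;
    }
}
```

For example, a flat array of { 0.7, 0.2, 0.1, 0.1, 0.1, 0.8 } with three classes describes two test vectors, whose most probable classes are 0 and 2 respectively.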
Using the code
GitHub
Solution structure
|
|-----XGBoost
|
|-----XGBoostTests
|       |
|       |---iris
|       |     |--Iris.data
|       |     |--Iris.test.data
|       |     |--Iris.train.data
|       |     |--Iris.cs
|       |
|       |---IrisUtils.cs
|       |---IrisUnitTest.cs
|       |---SimpleLinearClassifierTests.cs
|       |---XORClassifierTests.cs
|
License
This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL).