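This tutorial presents an end-to-end example of creating and using a predictive machine learning model with HeatWave ML. It walks you through preparing data, training a model with the ML_TRAIN routine, and generating predictions and explanations with the ML_PREDICT_* and ML_EXPLAIN_* routines. The tutorial also demonstrates how to assess the quality of a model with the ML_SCORE routine, and how to view model explanations to understand how the model works.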
For an online workshop based on this tutorial, see Getting Started with MySQL HeatWave Machine Learning.
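This tutorial uses the publicly available Iris data set from the UCI Machine Learning Repository.
Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science.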
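The Iris data set contains the following data, where the sepal and petal features are used to predict the class label, which is the type of Iris plant:
- sepal length (cm)
- sepal width (cm)
- petal length (cm)
- petal width (cm)
- class. Possible values include:
  - Iris-setosa
  - Iris-versicolor
  - Iris-virginica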
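The data is stored in the MySQL database in the following schema and tables:
- ml_data schema: The schema that contains the training and test data set tables.
- iris_train table: The training data set (labeled). Includes the feature columns (sepal length, sepal width, petal length, petal width) and a populated class target column with ground truth values.
- iris_test table: The test data set (unlabeled). Includes the feature columns (sepal length, sepal width, petal length, petal width) but no target column.
- iris_validate table: The validation data set (labeled). Includes the feature columns (sepal length, sepal width, petal length, petal width) and a populated class target column with ground truth values.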
This tutorial assumes that you have met the prerequisites outlined in Section 3.1, "Before You Begin".
- Create the sample schema and tables on the MySQL DB System by executing the following statements:
CREATE SCHEMA ml_data;
USE ml_data;
CREATE TABLE `iris_train` (
  `sepal length` float DEFAULT NULL,
  `sepal width` float DEFAULT NULL,
  `petal length` float DEFAULT NULL,
  `petal width` float DEFAULT NULL,
  `class` varchar(16) DEFAULT NULL
);
INSERT INTO iris_train VALUES(6.4,2.8,5.6,2.2,'Iris-virginica');
INSERT INTO iris_train VALUES(5.0,2.3,3.3,1.0,'Iris-setosa');
INSERT INTO iris_train VALUES(4.9,2.5,4.5,1.7,'Iris-virginica');
INSERT INTO iris_train VALUES(4.9,3.1,1.5,0.1,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.7,3.8,1.7,0.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.4,3.2,1.3,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.4,3.4,1.5,0.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.9,3.1,5.1,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(6.7,3.1,4.4,1.4,'Iris-setosa');
INSERT INTO iris_train VALUES(5.1,3.7,1.5,0.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.2,2.7,3.9,1.4,'Iris-setosa');
INSERT INTO iris_train VALUES(6.9,3.1,4.9,1.5,'Iris-setosa');
INSERT INTO iris_train VALUES(5.8,4.0,1.2,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.4,3.9,1.7,0.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(7.7,3.8,6.7,2.2,'Iris-virginica');
INSERT INTO iris_train VALUES(6.3,3.3,4.7,1.6,'Iris-setosa');
INSERT INTO iris_train VALUES(6.8,3.2,5.9,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(7.6,3.0,6.6,2.1,'Iris-virginica');
INSERT INTO iris_train VALUES(6.4,3.2,5.3,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(5.7,4.4,1.5,0.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.7,3.3,5.7,2.1,'Iris-virginica');
INSERT INTO iris_train VALUES(6.4,2.8,5.6,2.1,'Iris-virginica');
INSERT INTO iris_train VALUES(5.4,3.9,1.3,0.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.1,2.6,5.6,1.4,'Iris-virginica');
INSERT INTO iris_train VALUES(7.2,3.0,5.8,1.6,'Iris-virginica');
INSERT INTO iris_train VALUES(5.2,3.5,1.5,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.8,2.6,4.0,1.2,'Iris-setosa');
INSERT INTO iris_train VALUES(5.9,3.0,5.1,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(5.4,3.0,4.5,1.5,'Iris-setosa');
INSERT INTO iris_train VALUES(6.7,3.0,5.0,1.7,'Iris-setosa');
INSERT INTO iris_train VALUES(6.3,2.3,4.4,1.3,'Iris-setosa');
INSERT INTO iris_train VALUES(5.1,2.5,3.0,1.1,'Iris-setosa');
INSERT INTO iris_train VALUES(6.4,3.2,4.5,1.5,'Iris-setosa');
INSERT INTO iris_train VALUES(6.8,3.0,5.5,2.1,'Iris-virginica');
INSERT INTO iris_train VALUES(6.2,2.8,4.8,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(6.9,3.2,5.7,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(6.5,3.2,5.1,2.0,'Iris-virginica');
INSERT INTO iris_train VALUES(5.8,2.8,5.1,2.4,'Iris-virginica');
INSERT INTO iris_train VALUES(5.1,3.8,1.5,0.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.8,3.0,1.4,0.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(7.9,3.8,6.4,2.0,'Iris-virginica');
INSERT INTO iris_train VALUES(5.8,2.7,5.1,1.9,'Iris-virginica');
INSERT INTO iris_train VALUES(6.7,3.0,5.2,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(5.1,3.8,1.9,0.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.7,3.2,1.6,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.0,2.2,5.0,1.5,'Iris-virginica');
INSERT INTO iris_train VALUES(4.8,3.4,1.6,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(7.7,2.6,6.9,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(4.6,3.6,1.0,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(7.2,3.2,6.0,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(5.0,3.3,1.4,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.6,3.0,4.4,1.4,'Iris-setosa');
INSERT INTO iris_train VALUES(6.1,2.8,4.0,1.3,'Iris-setosa');
INSERT INTO iris_train VALUES(5.0,3.2,1.2,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(7.0,3.2,4.7,1.4,'Iris-setosa');
INSERT INTO iris_train VALUES(6.0,3.0,4.8,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(7.4,2.8,6.1,1.9,'Iris-virginica');
INSERT INTO iris_train VALUES(5.8,2.7,5.1,1.9,'Iris-virginica');
INSERT INTO iris_train VALUES(6.2,3.4,5.4,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(5.0,2.0,3.5,1.0,'Iris-setosa');
INSERT INTO iris_train VALUES(5.6,2.5,3.9,1.1,'Iris-setosa');
INSERT INTO iris_train VALUES(6.7,3.1,5.6,2.4,'Iris-virginica');
INSERT INTO iris_train VALUES(6.3,2.5,5.0,1.9,'Iris-virginica');
INSERT INTO iris_train VALUES(6.4,3.1,5.5,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(6.2,2.2,4.5,1.5,'Iris-setosa');
INSERT INTO iris_train VALUES(7.3,2.9,6.3,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(4.4,3.0,1.3,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(7.2,3.6,6.1,2.5,'Iris-virginica');
INSERT INTO iris_train VALUES(6.5,3.0,5.5,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(5.0,3.4,1.5,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.7,3.2,1.3,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.6,2.9,4.6,1.3,'Iris-setosa');
INSERT INTO iris_train VALUES(5.5,3.5,1.3,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(7.7,3.0,6.1,2.3,'Iris-virginica');
INSERT INTO iris_train VALUES(6.1,3.0,4.9,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(4.9,3.1,1.5,0.1,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.5,2.4,3.8,1.1,'Iris-setosa');
INSERT INTO iris_train VALUES(5.7,2.9,4.2,1.3,'Iris-setosa');
INSERT INTO iris_train VALUES(6.0,2.9,4.5,1.5,'Iris-setosa');
INSERT INTO iris_train VALUES(6.4,2.7,5.3,1.9,'Iris-virginica');
INSERT INTO iris_train VALUES(5.4,3.7,1.5,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.1,2.9,4.7,1.4,'Iris-setosa');
INSERT INTO iris_train VALUES(6.5,2.8,4.6,1.5,'Iris-setosa');
INSERT INTO iris_train VALUES(5.6,2.7,4.2,1.3,'Iris-setosa');
INSERT INTO iris_train VALUES(6.3,3.4,5.6,2.4,'Iris-virginica');
INSERT INTO iris_train VALUES(4.9,3.1,1.5,0.1,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.8,2.8,4.8,1.4,'Iris-setosa');
INSERT INTO iris_train VALUES(5.7,2.8,4.5,1.3,'Iris-setosa');
INSERT INTO iris_train VALUES(6.0,2.7,5.1,1.6,'Iris-setosa');
INSERT INTO iris_train VALUES(5.0,3.5,1.3,0.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.5,3.0,5.2,2.0,'Iris-virginica');
INSERT INTO iris_train VALUES(6.1,2.8,4.7,1.2,'Iris-setosa');
INSERT INTO iris_train VALUES(5.1,3.5,1.4,0.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.6,3.1,1.5,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.5,3.0,5.8,2.2,'Iris-virginica');
INSERT INTO iris_train VALUES(4.6,3.4,1.4,0.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.6,3.2,1.4,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(7.7,2.8,6.7,2.0,'Iris-virginica');
INSERT INTO iris_train VALUES(5.9,3.2,4.8,1.8,'Iris-setosa');
INSERT INTO iris_train VALUES(5.1,3.8,1.6,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.9,3.0,1.4,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.9,2.4,3.3,1.0,'Iris-setosa');
INSERT INTO iris_train VALUES(4.5,2.3,1.3,0.3,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.8,2.7,4.1,1.0,'Iris-setosa');
INSERT INTO iris_train VALUES(5.0,3.4,1.6,0.4,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.2,3.4,1.4,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.3,3.7,1.5,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.0,3.6,1.4,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.6,2.9,3.6,1.3,'Iris-setosa');
INSERT INTO iris_train VALUES(4.8,3.1,1.6,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.3,2.7,4.9,1.8,'Iris-virginica');
INSERT INTO iris_train VALUES(5.7,2.8,4.1,1.3,'Iris-setosa');
INSERT INTO iris_train VALUES(5.0,3.0,1.6,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(6.3,3.3,6.0,2.5,'Iris-virginica');
INSERT INTO iris_train VALUES(5.0,3.5,1.6,0.6,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.5,2.6,4.4,1.2,'Iris-setosa');
INSERT INTO iris_train VALUES(5.7,3.0,4.2,1.2,'Iris-setosa');
INSERT INTO iris_train VALUES(4.4,2.9,1.4,0.2,'Iris-versicolor');
INSERT INTO iris_train VALUES(4.8,3.0,1.4,0.1,'Iris-versicolor');
INSERT INTO iris_train VALUES(5.5,2.4,3.7,1.0,'Iris-setosa');
CREATE TABLE `iris_test` LIKE `iris_train`;
INSERT INTO iris_test VALUES(5.9,3.0,4.2,1.5,'Iris-setosa');
INSERT INTO iris_test VALUES(6.9,3.1,5.4,2.1,'Iris-virginica');
INSERT INTO iris_test VALUES(5.1,3.3,1.7,0.5,'Iris-versicolor');
INSERT INTO iris_test VALUES(6.0,3.4,4.5,1.6,'Iris-setosa');
INSERT INTO iris_test VALUES(5.5,2.5,4.0,1.3,'Iris-setosa');
INSERT INTO iris_test VALUES(6.2,2.9,4.3,1.3,'Iris-setosa');
INSERT INTO iris_test VALUES(5.5,4.2,1.4,0.2,'Iris-versicolor');
INSERT INTO iris_test VALUES(6.3,2.8,5.1,1.5,'Iris-virginica');
INSERT INTO iris_test VALUES(5.6,3.0,4.1,1.3,'Iris-setosa');
INSERT INTO iris_test VALUES(6.7,2.5,5.8,1.8,'Iris-virginica');
INSERT INTO iris_test VALUES(7.1,3.0,5.9,2.1,'Iris-virginica');
INSERT INTO iris_test VALUES(4.3,3.0,1.1,0.1,'Iris-versicolor');
INSERT INTO iris_test VALUES(5.6,2.8,4.9,2.0,'Iris-virginica');
INSERT INTO iris_test VALUES(5.5,2.3,4.0,1.3,'Iris-setosa');
INSERT INTO iris_test VALUES(6.0,2.2,4.0,1.0,'Iris-setosa');
INSERT INTO iris_test VALUES(5.1,3.5,1.4,0.2,'Iris-versicolor');
INSERT INTO iris_test VALUES(5.7,2.6,3.5,1.0,'Iris-setosa');
INSERT INTO iris_test VALUES(4.8,3.4,1.9,0.2,'Iris-versicolor');
INSERT INTO iris_test VALUES(5.1,3.4,1.5,0.2,'Iris-versicolor');
INSERT INTO iris_test VALUES(5.7,2.5,5.0,2.0,'Iris-virginica');
INSERT INTO iris_test VALUES(5.4,3.4,1.7,0.2,'Iris-versicolor');
INSERT INTO iris_test VALUES(5.6,3.0,4.5,1.5,'Iris-setosa');
INSERT INTO iris_test VALUES(6.3,2.9,5.6,1.8,'Iris-virginica');
INSERT INTO iris_test VALUES(6.3,2.5,4.9,1.5,'Iris-setosa');
INSERT INTO iris_test VALUES(5.8,2.7,3.9,1.2,'Iris-setosa');
INSERT INTO iris_test VALUES(6.1,3.0,4.6,1.4,'Iris-setosa');
INSERT INTO iris_test VALUES(5.2,4.1,1.5,0.1,'Iris-versicolor');
INSERT INTO iris_test VALUES(6.7,3.1,4.7,1.5,'Iris-setosa');
INSERT INTO iris_test VALUES(6.7,3.3,5.7,2.5,'Iris-virginica');
INSERT INTO iris_test VALUES(6.4,2.9,4.3,1.3,'Iris-setosa');
CREATE TABLE `iris_validate` LIKE `iris_test`;
INSERT INTO `iris_validate` SELECT * FROM `iris_test`;
ALTER TABLE `iris_test` DROP COLUMN `class`;
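Before training, it can be useful to confirm that the tables were populated as intended. The queries below are a minimal sketch of such a check, assuming the statements above ran without errors; the aliases table_name and row_count are illustrative names and are not part of the tutorial.

```sql
-- Row counts for the three tutorial tables.
SELECT 'iris_train' AS table_name, COUNT(*) AS row_count FROM ml_data.iris_train
UNION ALL
SELECT 'iris_test', COUNT(*) FROM ml_data.iris_test
UNION ALL
SELECT 'iris_validate', COUNT(*) FROM ml_data.iris_validate;

-- Class distribution in the labeled training data.
SELECT `class`, COUNT(*) AS row_count
FROM ml_data.iris_train
GROUP BY `class`
ORDER BY `class`;

-- After the ALTER TABLE, iris_test should list only the four feature columns.
SHOW COLUMNS FROM ml_data.iris_test;
```

Column names that contain spaces, such as `sepal length`, must be quoted with backticks whenever they are referenced in a query.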
- Train the model using ML_TRAIN. Because this is a classification data set, the classification task is specified to create a classification model:
mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', JSON_OBJECT('task', 'classification'), @iris_model);
When the training operation finishes, the model handle is assigned to the @iris_model session variable, and the model is stored in your model catalog. You can view the entry in your model catalog using the following query, where user1 is your MySQL account name:
mysql> SELECT model_id, model_handle, train_table_name FROM ML_SCHEMA_user1.MODEL_CATALOG;
+----------+---------------------------------------+--------------------+
| model_id | model_handle                          | train_table_name   |
+----------+---------------------------------------+--------------------+
|        1 | ml_data.iris_train_user1_1648140791   | ml_data.iris_train |
+----------+---------------------------------------+--------------------+
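Session variables do not survive a disconnect, so in a new session @iris_model would be empty. One possible way to restore it is to read the handle back from the model catalog. This is a sketch under two assumptions: user1 again stands in for your MySQL account name, and the most recently created catalog entry (highest model_id) is the model you want.

```sql
-- Assumption: the newest catalog entry is the Iris model trained above.
SET @iris_model = (
  SELECT model_handle
  FROM ML_SCHEMA_user1.MODEL_CATALOG
  ORDER BY model_id DESC
  LIMIT 1
);

-- Confirm which handle the session variable now holds.
SELECT @iris_model;
```

If the catalog holds several models, filtering on train_table_name = 'ml_data.iris_train' instead of relying on the ordering would be a more explicit choice.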
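- Load the model into HeatWave ML using the ML_MODEL_LOAD routine:
mysql> CALL sys.ML_MODEL_LOAD(@iris_model, NULL);
A model must be loaded before you can use it. The model remains loaded until you unload it or the HeatWave Cluster is restarted.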
- Make a prediction for a single row of data using the ML_PREDICT_ROW routine. In this example, the row data is assigned to a @row_input session variable, and the variable is passed to the routine. The model handle is passed using the @iris_model session variable:
mysql> SET @row_input = JSON_OBJECT(
          "sepal length", 7.3,
          "sepal width", 2.9,
          "petal length", 6.3,
          "petal width", 1.8);
mysql> SELECT sys.ML_PREDICT_ROW(@row_input, @iris_model);
+---------------------------------------------------------------------------+
| sys.ML_PREDICT_ROW(@row_input, @iris_model)                               |
+---------------------------------------------------------------------------+
| {"Prediction": "Iris-virginica", "petal width": 1.8, "sepal width": 2.9,  |
| "petal length": 6.3, "sepal length": 7.3}                                 |
+---------------------------------------------------------------------------+
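Based on the feature inputs that were provided, the model predicts that the Iris plant is of the class Iris-virginica. The feature values used to make the prediction are also shown.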
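When only the predicted label is needed, it can be pulled out of the returned document with MySQL's standard JSON functions. The query below is a sketch that assumes ML_PREDICT_ROW returns a JSON document shaped like the output shown above; predicted_class is an illustrative alias.

```sql
-- Extract just the "Prediction" member and strip the JSON quotes.
SELECT JSON_UNQUOTE(
         JSON_EXTRACT(
           sys.ML_PREDICT_ROW(@row_input, @iris_model),
           '$.Prediction'
         )
       ) AS predicted_class;
```

Keys that contain spaces, such as "petal width", need a quoted path of the form '$."petal width"'.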
- Now, generate an explanation for the same row of data using the ML_EXPLAIN_ROW routine to understand how the prediction was made:
mysql> SELECT sys.ML_EXPLAIN_ROW(@row_input, @iris_model);
+------------------------------------------------------------------------------+
| sys.ML_EXPLAIN_ROW(@row_input, @iris_model)                                  |
+------------------------------------------------------------------------------+
| {"Prediction": "Iris-virginica", "petal width": 1.8, "sepal width": 2.9,     |
| "petal length": 6.3, "sepal length": 7.3, "petal width_attribution": 0.2496, |
| "petal length_attribution": 0.9997}                                          |
+------------------------------------------------------------------------------+
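The attribution values show which features contributed most to the prediction, with petal length and petal width being the most important features. The other features have a value of 0, which indicates that they did not contribute to the prediction.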
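- Make predictions for a table of data using the ML_PREDICT_TABLE routine. The routine takes the data in the iris_test table as input and writes the predictions to an iris_predictions output table.
mysql> CALL sys.ML_PREDICT_TABLE('ml_data.iris_test', @iris_model, 'ml_data.iris_predictions');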
To view the ML_PREDICT_TABLE results, query the output table; for example:
mysql> SELECT * FROM iris_predictions LIMIT 3\G
*************************** 1. row ***************************
sepal length: 5.9
 sepal width: 3
petal length: 4.2
 petal width: 1.5
  Prediction: Iris-setosa
*************************** 2. row ***************************
sepal length: 6.9
 sepal width: 3.1
petal length: 5.4
 petal width: 2.1
  Prediction: Iris-virginica
*************************** 3. row ***************************
sepal length: 5.1
 sepal width: 3.3
petal length: 1.7
 petal width: 0.5
  Prediction: Iris-versicolor
The table shows the predictions and the feature column values used to make each prediction.
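Because the predictions are stored in an ordinary table, they can be summarized with regular SQL. As a small sketch, assuming the output table has the Prediction column shown above, the following query counts how many test rows were assigned to each class; row_count is an illustrative alias.

```sql
-- Number of test rows predicted for each class.
SELECT Prediction, COUNT(*) AS row_count
FROM ml_data.iris_predictions
GROUP BY Prediction
ORDER BY row_count DESC;
```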
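- Generate explanations for the same table of data using the ML_EXPLAIN_TABLE routine. Explanations help you understand which features have the most influence on a prediction. Feature importance is presented as an attribution value ranging from -1 to 1. A positive value indicates that a feature contributed toward the prediction. A negative value indicates that the feature contributed positively toward one of the other possible predictions.
mysql> CALL sys.ML_EXPLAIN_TABLE('ml_data.iris_test', @iris_model, 'ml_data.iris_explanations');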
To view the ML_EXPLAIN_TABLE results, query the output table; for example:
mysql> SELECT * FROM iris_explanations LIMIT 3\G
*************************** 1. row ***************************
            sepal length: 5.9
             sepal width: 3
            petal length: 4.2
             petal width: 1.5
              Prediction: Iris-setosa
petal length_attribution: -0.0088
 petal width_attribution: 0.1793
*************************** 2. row ***************************
            sepal length: 6.9
             sepal width: 3.1
            petal length: 5.4
             petal width: 2.1
              Prediction: Iris-virginica
petal length_attribution: 0.9723
 petal width_attribution: 0.6712
*************************** 3. row ***************************
            sepal length: 5.1
             sepal width: 3.3
            petal length: 1.7
             petal width: 0.5
              Prediction: Iris-versicolor
petal length_attribution: 0.5373
 petal width_attribution: 0.3529
3 rows in set (0.0006 sec)
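The attribution columns can also be aggregated to get a rough, table-wide view of feature influence. The queries below are a sketch that assumes the output table contains the two numeric attribution columns shown above; the avg_abs_* aliases are illustrative names.

```sql
-- Average magnitude of each attribution across all explained rows.
SELECT
  AVG(ABS(`petal length_attribution`)) AS avg_abs_petal_length_attr,
  AVG(ABS(`petal width_attribution`))  AS avg_abs_petal_width_attr
FROM ml_data.iris_explanations;

-- Rows where petal length pointed toward a class other than the predicted one.
SELECT *
FROM ml_data.iris_explanations
WHERE `petal length_attribution` < 0;
```

Averaging absolute values keeps positive and negative attributions from cancelling each other out; it is one simple summary among several that could be used.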
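- Score the model using ML_SCORE to assess the model's reliability. This example uses the balanced_accuracy metric, which is one of the many scoring metrics supported by HeatWave ML.
mysql> CALL sys.ML_SCORE('ml_data.iris_validate', 'class', @iris_model, 'balanced_accuracy', @score);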
To retrieve the computed score, query the @score session variable.
mysql> SELECT @score;
+--------------------+
| @score             |
+--------------------+
| 0.9583333134651184 |
+--------------------+
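For readers unfamiliar with the metric, balanced accuracy is commonly defined (for example, in scikit-learn) as the unweighted mean of the per-class recall. Assuming HeatWave ML follows this usual definition, the score for the K = 3 Iris classes would be computed as follows, where TP_k and FN_k are the true positives and false negatives for class k:

```latex
\text{balanced accuracy}
  = \frac{1}{K} \sum_{k=1}^{K} \frac{TP_k}{TP_k + FN_k}
```

As a purely hypothetical illustration (these recall values are not taken from the tutorial's data), per-class recalls of 1, 1, and 0.875 would average to (1 + 1 + 0.875) / 3 ≈ 0.9583.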
- Unload the model using ML_MODEL_UNLOAD:
mysql> CALL sys.ML_MODEL_UNLOAD(@iris_model);
To avoid consuming too much space, it is good practice to unload a model when you are finished using it.