6.3 Iris Data Set Machine Learning Quickstart

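This tutorial presents an end-to-end example of creating and using a predictive machine learning model with HeatWave ML. It walks you through preparing data, training a model with the ML_TRAIN routine, and generating predictions and explanations with the ML_PREDICT_* and ML_EXPLAIN_* routines. The tutorial also demonstrates how to assess the quality of a model using the ML_SCORE routine, and how to view model explanations to understand how the model works.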

For an online workshop based on this tutorial, see Getting Started with MySQL HeatWave Machine Learning.

This tutorial uses the publicly available Iris Data Set from the UCI Machine Learning Repository.

Dua, D. and Graff, C. (2019). UCI Machine Learning Repository [http://archive.ics.uci.edu/ml]. Irvine, CA: University of California, School of Information and Computer Science.

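The Iris Data Set contains the following data. The sepal and petal features are used to predict the class label, which is the type of Iris plant:

  • sepal length (cm)

  • sepal width (cm)

  • petal length (cm)

  • petal width (cm)

  • class. Possible values include:

    • Iris Setosa

    • Iris Versicolor

    • Iris Virginica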

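The data is stored in the MySQL database in the following schema and tables:

  • ml_data schema: The schema containing the training and test data set tables.

  • iris_train table: The training data set (labeled). Includes the feature columns (sepal length, sepal width, petal length, petal width) and a class target column populated with ground truth values.

  • iris_test table: The test data set (unlabeled). Includes the feature columns (sepal length, sepal width, petal length, petal width) but no target column.

  • iris_validate table: The validation data set (labeled). Includes the feature columns (sepal length, sepal width, petal length, petal width) and a class target column populated with ground truth values.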

This tutorial assumes that you have met the prerequisites outlined in Section 3.1, “Before You Begin”.

  1. Create the example schema and tables on the MySQL DB System by executing the following statements:

    CREATE SCHEMA ml_data;
    
    USE ml_data;
    
    CREATE TABLE `iris_train` (
      `sepal length` float DEFAULT NULL,
      `sepal width` float DEFAULT NULL,
      `petal length` float DEFAULT NULL,
      `petal width` float DEFAULT NULL,
      `class` varchar(16) DEFAULT NULL
    );
    
    INSERT INTO iris_train VALUES(6.4,2.8,5.6,2.2,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.0,2.3,3.3,1.0,'Iris-setosa');
    INSERT INTO iris_train VALUES(4.9,2.5,4.5,1.7,'Iris-virginica');
    INSERT INTO iris_train VALUES(4.9,3.1,1.5,0.1,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.7,3.8,1.7,0.3,'Iris-versicolor');
    INSERT INTO iris_train VALUES(4.4,3.2,1.3,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.4,3.4,1.5,0.4,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.9,3.1,5.1,2.3,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.7,3.1,4.4,1.4,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.1,3.7,1.5,0.4,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.2,2.7,3.9,1.4,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.9,3.1,4.9,1.5,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.8,4.0,1.2,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.4,3.9,1.7,0.4,'Iris-versicolor');
    INSERT INTO iris_train VALUES(7.7,3.8,6.7,2.2,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.3,3.3,4.7,1.6,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.8,3.2,5.9,2.3,'Iris-virginica');
    INSERT INTO iris_train VALUES(7.6,3.0,6.6,2.1,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.4,3.2,5.3,2.3,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.7,4.4,1.5,0.4,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.7,3.3,5.7,2.1,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.4,2.8,5.6,2.1,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.4,3.9,1.3,0.4,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.1,2.6,5.6,1.4,'Iris-virginica');
    INSERT INTO iris_train VALUES(7.2,3.0,5.8,1.6,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.2,3.5,1.5,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.8,2.6,4.0,1.2,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.9,3.0,5.1,1.8,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.4,3.0,4.5,1.5,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.7,3.0,5.0,1.7,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.3,2.3,4.4,1.3,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.1,2.5,3.0,1.1,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.4,3.2,4.5,1.5,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.8,3.0,5.5,2.1,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.2,2.8,4.8,1.8,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.9,3.2,5.7,2.3,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.5,3.2,5.1,2.0,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.8,2.8,5.1,2.4,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.1,3.8,1.5,0.3,'Iris-versicolor');
    INSERT INTO iris_train VALUES(4.8,3.0,1.4,0.3,'Iris-versicolor');
    INSERT INTO iris_train VALUES(7.9,3.8,6.4,2.0,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.8,2.7,5.1,1.9,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.7,3.0,5.2,2.3,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.1,3.8,1.9,0.4,'Iris-versicolor');
    INSERT INTO iris_train VALUES(4.7,3.2,1.6,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.0,2.2,5.0,1.5,'Iris-virginica');
    INSERT INTO iris_train VALUES(4.8,3.4,1.6,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(7.7,2.6,6.9,2.3,'Iris-virginica');
    INSERT INTO iris_train VALUES(4.6,3.6,1.0,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(7.2,3.2,6.0,1.8,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.0,3.3,1.4,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.6,3.0,4.4,1.4,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.1,2.8,4.0,1.3,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.0,3.2,1.2,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(7.0,3.2,4.7,1.4,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.0,3.0,4.8,1.8,'Iris-virginica');
    INSERT INTO iris_train VALUES(7.4,2.8,6.1,1.9,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.8,2.7,5.1,1.9,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.2,3.4,5.4,2.3,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.0,2.0,3.5,1.0,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.6,2.5,3.9,1.1,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.7,3.1,5.6,2.4,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.3,2.5,5.0,1.9,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.4,3.1,5.5,1.8,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.2,2.2,4.5,1.5,'Iris-setosa');
    INSERT INTO iris_train VALUES(7.3,2.9,6.3,1.8,'Iris-virginica');
    INSERT INTO iris_train VALUES(4.4,3.0,1.3,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(7.2,3.6,6.1,2.5,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.5,3.0,5.5,1.8,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.0,3.4,1.5,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(4.7,3.2,1.3,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.6,2.9,4.6,1.3,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.5,3.5,1.3,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(7.7,3.0,6.1,2.3,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.1,3.0,4.9,1.8,'Iris-virginica');
    INSERT INTO iris_train VALUES(4.9,3.1,1.5,0.1,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.5,2.4,3.8,1.1,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.7,2.9,4.2,1.3,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.0,2.9,4.5,1.5,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.4,2.7,5.3,1.9,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.4,3.7,1.5,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.1,2.9,4.7,1.4,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.5,2.8,4.6,1.5,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.6,2.7,4.2,1.3,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.3,3.4,5.6,2.4,'Iris-virginica');
    INSERT INTO iris_train VALUES(4.9,3.1,1.5,0.1,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.8,2.8,4.8,1.4,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.7,2.8,4.5,1.3,'Iris-setosa');
    INSERT INTO iris_train VALUES(6.0,2.7,5.1,1.6,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.0,3.5,1.3,0.3,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.5,3.0,5.2,2.0,'Iris-virginica');
    INSERT INTO iris_train VALUES(6.1,2.8,4.7,1.2,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.1,3.5,1.4,0.3,'Iris-versicolor');
    INSERT INTO iris_train VALUES(4.6,3.1,1.5,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.5,3.0,5.8,2.2,'Iris-virginica');
    INSERT INTO iris_train VALUES(4.6,3.4,1.4,0.3,'Iris-versicolor');
    INSERT INTO iris_train VALUES(4.6,3.2,1.4,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(7.7,2.8,6.7,2.0,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.9,3.2,4.8,1.8,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.1,3.8,1.6,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(4.9,3.0,1.4,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(4.9,2.4,3.3,1.0,'Iris-setosa');
    INSERT INTO iris_train VALUES(4.5,2.3,1.3,0.3,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.8,2.7,4.1,1.0,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.0,3.4,1.6,0.4,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.2,3.4,1.4,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.3,3.7,1.5,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.0,3.6,1.4,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.6,2.9,3.6,1.3,'Iris-setosa');
    INSERT INTO iris_train VALUES(4.8,3.1,1.6,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.3,2.7,4.9,1.8,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.7,2.8,4.1,1.3,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.0,3.0,1.6,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(6.3,3.3,6.0,2.5,'Iris-virginica');
    INSERT INTO iris_train VALUES(5.0,3.5,1.6,0.6,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.5,2.6,4.4,1.2,'Iris-setosa');
    INSERT INTO iris_train VALUES(5.7,3.0,4.2,1.2,'Iris-setosa');
    INSERT INTO iris_train VALUES(4.4,2.9,1.4,0.2,'Iris-versicolor');
    INSERT INTO iris_train VALUES(4.8,3.0,1.4,0.1,'Iris-versicolor');
    INSERT INTO iris_train VALUES(5.5,2.4,3.7,1.0,'Iris-setosa');
    
    CREATE TABLE `iris_test` LIKE `iris_train`;
    
    INSERT INTO iris_test VALUES(5.9,3.0,4.2,1.5,'Iris-setosa');
    INSERT INTO iris_test VALUES(6.9,3.1,5.4,2.1,'Iris-virginica');
    INSERT INTO iris_test VALUES(5.1,3.3,1.7,0.5,'Iris-versicolor');
    INSERT INTO iris_test VALUES(6.0,3.4,4.5,1.6,'Iris-setosa');
    INSERT INTO iris_test VALUES(5.5,2.5,4.0,1.3,'Iris-setosa');
    INSERT INTO iris_test VALUES(6.2,2.9,4.3,1.3,'Iris-setosa');
    INSERT INTO iris_test VALUES(5.5,4.2,1.4,0.2,'Iris-versicolor');
    INSERT INTO iris_test VALUES(6.3,2.8,5.1,1.5,'Iris-virginica');
    INSERT INTO iris_test VALUES(5.6,3.0,4.1,1.3,'Iris-setosa');
    INSERT INTO iris_test VALUES(6.7,2.5,5.8,1.8,'Iris-virginica');
    INSERT INTO iris_test VALUES(7.1,3.0,5.9,2.1,'Iris-virginica');
    INSERT INTO iris_test VALUES(4.3,3.0,1.1,0.1,'Iris-versicolor');
    INSERT INTO iris_test VALUES(5.6,2.8,4.9,2.0,'Iris-virginica');
    INSERT INTO iris_test VALUES(5.5,2.3,4.0,1.3,'Iris-setosa');
    INSERT INTO iris_test VALUES(6.0,2.2,4.0,1.0,'Iris-setosa');
    INSERT INTO iris_test VALUES(5.1,3.5,1.4,0.2,'Iris-versicolor');
    INSERT INTO iris_test VALUES(5.7,2.6,3.5,1.0,'Iris-setosa');
    INSERT INTO iris_test VALUES(4.8,3.4,1.9,0.2,'Iris-versicolor');
    INSERT INTO iris_test VALUES(5.1,3.4,1.5,0.2,'Iris-versicolor');
    INSERT INTO iris_test VALUES(5.7,2.5,5.0,2.0,'Iris-virginica');
    INSERT INTO iris_test VALUES(5.4,3.4,1.7,0.2,'Iris-versicolor');
    INSERT INTO iris_test VALUES(5.6,3.0,4.5,1.5,'Iris-setosa');
    INSERT INTO iris_test VALUES(6.3,2.9,5.6,1.8,'Iris-virginica');
    INSERT INTO iris_test VALUES(6.3,2.5,4.9,1.5,'Iris-setosa');
    INSERT INTO iris_test VALUES(5.8,2.7,3.9,1.2,'Iris-setosa');
    INSERT INTO iris_test VALUES(6.1,3.0,4.6,1.4,'Iris-setosa');
    INSERT INTO iris_test VALUES(5.2,4.1,1.5,0.1,'Iris-versicolor');
    INSERT INTO iris_test VALUES(6.7,3.1,4.7,1.5,'Iris-setosa');
    INSERT INTO iris_test VALUES(6.7,3.3,5.7,2.5,'Iris-virginica');
    INSERT INTO iris_test VALUES(6.4,2.9,4.3,1.3,'Iris-setosa');
    
    CREATE TABLE `iris_validate` LIKE `iris_test`;
    
    INSERT INTO `iris_validate` SELECT * FROM `iris_test`;
    
    ALTER TABLE `iris_test` DROP COLUMN `class`;
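
    As a quick sanity check on the load, the following sketch counts the rows in each table and confirms that iris_test no longer has a class column. It uses only standard MySQL statements; the tbl and n_rows aliases are illustrative names that do not appear elsewhere in this tutorial.

    ```sql
    -- Count rows per table. The script above inserts 120 training rows and
    -- 30 test rows, and copies the test rows into iris_validate.
    SELECT 'iris_train' AS tbl, COUNT(*) AS n_rows FROM ml_data.iris_train
    UNION ALL
    SELECT 'iris_test', COUNT(*) FROM ml_data.iris_test
    UNION ALL
    SELECT 'iris_validate', COUNT(*) FROM ml_data.iris_validate;

    -- Confirm that the class column was dropped from the unlabeled test table.
    SHOW COLUMNS FROM ml_data.iris_test;

    -- Inspect how the labels are distributed in the training data.
    SELECT `class`, COUNT(*) AS n_rows
    FROM ml_data.iris_train
    GROUP BY `class`;
    ```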

  2. Train the model using ML_TRAIN. Because this is a classification data set, the classification task is specified to create a classification model:

    mysql> CALL sys.ML_TRAIN('ml_data.iris_train', 'class', 
    JSON_OBJECT('task', 'classification'), @iris_model);

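    When the training operation finishes, the model handle is assigned to the @iris_model session variable, and the model is stored in your model catalog. You can view the entry in your model catalog using the following query, where user1 is your MySQL account name: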

    mysql> SELECT model_id, model_handle, train_table_name FROM ML_SCHEMA_user1.MODEL_CATALOG;
    +----------+---------------------------------------+--------------------+
    | model_id | model_handle                          | train_table_name   |
    +----------+---------------------------------------+--------------------+
    |        1 | ml_data.iris_train_user1_1648140791   | ml_data.iris_train |
    +----------+---------------------------------------+--------------------+
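
    Session variables such as @iris_model are lost when the connection closes. The following sketch shows one possible way to restore the handle from the model catalog. It assumes the user1 account name from the example above, and that the most recently trained model (the highest model_id) is the one you want.

    ```sql
    -- Assumption: the highest model_id in the catalog is the Iris model
    -- that was just trained.
    SELECT model_handle
    FROM ML_SCHEMA_user1.MODEL_CATALOG
    ORDER BY model_id DESC
    LIMIT 1
    INTO @iris_model;

    -- Verify that the variable now holds a model handle.
    SELECT @iris_model;
    ```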

  3. Load the model into HeatWave ML using the ML_MODEL_LOAD routine:

    mysql> CALL sys.ML_MODEL_LOAD(@iris_model, NULL);

    A model must be loaded before you can use it. The model remains loaded until you unload it or the HeatWave Cluster is restarted.

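  4. Make a prediction for a single row of data using the ML_PREDICT_ROW routine. In this example, the data is assigned to a @row_input session variable, which is passed to the routine. The model handle is passed using the @iris_model session variable: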

    mysql> SET @row_input = JSON_OBJECT( 
    "sepal length", 7.3, 
    "sepal width", 2.9, 
    "petal length", 6.3, 
    "petal width", 1.8);  
    
    mysql> SELECT sys.ML_PREDICT_ROW(@row_input, @iris_model);
    +---------------------------------------------------------------------------+
    | sys.ML_PREDICT_ROW(@row_input, @iris_model)                               |
    +---------------------------------------------------------------------------+
    | {"Prediction": "Iris-virginica", "petal width": 1.8, "sepal width": 2.9,  |
    | "petal length": 6.3, "sepal length": 7.3}                                 |
    +---------------------------------------------------------------------------+

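    Based on the feature inputs that were provided, the model predicts that the Iris plant is of the class Iris-virginica. The feature values used to make the prediction are also shown.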
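
    Because the result is a JSON document, the predicted label can be extracted with the standard MySQL JSON functions. This is a minimal sketch that assumes the return value has the structure shown in the output above; the predicted_class alias is an illustrative name.

    ```sql
    -- Extract only the "Prediction" member from the JSON result.
    SELECT JSON_UNQUOTE(
             JSON_EXTRACT(sys.ML_PREDICT_ROW(@row_input, @iris_model),
                          '$.Prediction')
           ) AS predicted_class;
    ```

    For the row used in this step, the query should return the same Iris-virginica label that appears in the full output.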

  5. Now, generate an explanation for the same row of data using the ML_EXPLAIN_ROW routine to understand how the prediction was made:

    mysql> SELECT sys.ML_EXPLAIN_ROW(@row_input, @iris_model);
    +------------------------------------------------------------------------------+
    | sys.ML_EXPLAIN_ROW(@row_input, @iris_model)                                  |
    +------------------------------------------------------------------------------+
    | {"Prediction": "Iris-virginica", "petal width": 1.8, "sepal width": 2.9,     |
    | "petal length": 6.3, "sepal length": 7.3, "petal width_attribution": 0.2496, |
    | "petal length_attribution": 0.9997}                                          |
    +------------------------------------------------------------------------------+

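    The attribution values show which features contributed most to the prediction, with petal length and petal width being the most important features. The other features have a value of 0, indicating that they did not contribute to the prediction.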
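
    To compare attributions side by side, the explanation can be stored in a session variable and its members extracted individually. The following is a sketch that assumes the JSON structure shown in the output above; @explanation and the column aliases are illustrative names. Member names that contain a space must be double-quoted inside the JSON path.

    ```sql
    -- Store the explanation once so the routine is not invoked per member.
    SELECT sys.ML_EXPLAIN_ROW(@row_input, @iris_model) INTO @explanation;

    -- Pull out the two attribution members reported for this row.
    SELECT
      JSON_EXTRACT(@explanation, '$."petal length_attribution"') AS petal_length_attribution,
      JSON_EXTRACT(@explanation, '$."petal width_attribution"')  AS petal_width_attribution;
    ```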

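  6. Make predictions for a table of data using the ML_PREDICT_TABLE routine. The routine takes the data in the iris_test table as input and writes the predictions to an iris_predictions output table.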

    mysql> CALL sys.ML_PREDICT_TABLE('ml_data.iris_test', @iris_model, 
    'ml_data.iris_predictions');

    To view the ML_PREDICT_TABLE results, query the output table; for example:

    mysql> SELECT * FROM iris_predictions LIMIT 3\G
    *************************** 1. row ***************************
    sepal length: 5.9
     sepal width: 3
    petal length: 4.2
     petal width: 1.5
      Prediction: Iris-setosa
    *************************** 2. row ***************************
    sepal length: 6.9
     sepal width: 3.1
    petal length: 5.4
     petal width: 2.1
      Prediction: Iris-virginica
    *************************** 3. row ***************************
    sepal length: 5.1
     sepal width: 3.3
    petal length: 1.7
     petal width: 0.5
      Prediction: Iris-versicolor

    The table shows the predictions and the feature column values used to make each prediction.
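
    Because the output is an ordinary table, it can be summarized with standard SQL. For example, the following sketch counts how many test rows were assigned to each class; n_rows is an illustrative alias.

    ```sql
    -- Summarize the predictions per class.
    SELECT Prediction, COUNT(*) AS n_rows
    FROM ml_data.iris_predictions
    GROUP BY Prediction
    ORDER BY n_rows DESC;

    -- Assumption: if you rerun ML_PREDICT_TABLE with the same output table
    -- name, you may need to drop the existing table first.
    -- DROP TABLE IF EXISTS ml_data.iris_predictions;
    ```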

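  7. Generate explanations for the same table of data using the ML_EXPLAIN_TABLE routine.

    Explanations help you understand which features have the most influence on a prediction. Feature importance is presented as an attribution value ranging from -1 to 1. A positive value indicates that a feature contributed toward the prediction. A negative value indicates that the feature contributed toward one of the other possible predictions.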

    mysql> CALL sys.ML_EXPLAIN_TABLE('ml_data.iris_test', @iris_model, 
    'ml_data.iris_explanations');

    To view the ML_EXPLAIN_TABLE results, query the output table; for example:

    mysql> SELECT * FROM iris_explanations LIMIT 3\G
    *************************** 1. row ***************************
                sepal length: 5.9
                 sepal width: 3
                petal length: 4.2
                 petal width: 1.5
                  Prediction: Iris-setosa
    petal length_attribution: -0.0088
     petal width_attribution: 0.1793
    *************************** 2. row ***************************
                sepal length: 6.9
                 sepal width: 3.1
                petal length: 5.4
                 petal width: 2.1
                  Prediction: Iris-virginica
    petal length_attribution: 0.9723
     petal width_attribution: 0.6712
    *************************** 3. row ***************************
                sepal length: 5.1
                 sepal width: 3.3
                petal length: 1.7
                 petal width: 0.5
                  Prediction: Iris-versicolor
    petal length_attribution: 0.5373
     petal width_attribution: 0.3529
    3 rows in set (0.0006 sec)
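
    One way to get a table-wide view is to average the absolute attribution of each feature across all explained rows. The following sketch uses the attribution columns shown in the output above; the aliases are illustrative, and averaging absolute values is only one of several reasonable summaries.

    ```sql
    -- Average magnitude of each feature's attribution over the whole table.
    SELECT
      AVG(ABS(`petal length_attribution`)) AS avg_abs_petal_length_attr,
      AVG(ABS(`petal width_attribution`))  AS avg_abs_petal_width_attr
    FROM ml_data.iris_explanations;
    ```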
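
  8. Score the model using ML_SCORE to assess the model's reliability. This example uses the balanced_accuracy metric, which is one of the many scoring metrics supported by HeatWave ML.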

    mysql> CALL sys.ML_SCORE('ml_data.iris_validate', 'class', @iris_model, 
    'balanced_accuracy', @score);

    To retrieve the computed score, query the @score session variable.

    mysql> SELECT @score;
    +--------------------+
    | @score             |
    +--------------------+
    | 0.9583333134651184 |
    +--------------------+
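
    Other metrics can be requested by changing the fourth argument. The following sketch assumes that accuracy is among the classification metrics supported by your HeatWave ML version. It stores the result in a separate, illustratively named @accuracy variable so that both scores can be compared.

    ```sql
    -- Assumption: 'accuracy' is a supported classification metric.
    CALL sys.ML_SCORE('ml_data.iris_validate', 'class', @iris_model,
                      'accuracy', @accuracy);

    -- Show both scores next to each other.
    SELECT @score AS balanced_accuracy, @accuracy AS accuracy;
    ```

    Balanced accuracy is commonly defined as the mean of the per-class recall values, so it can differ from plain accuracy when the classes are unevenly represented.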

  9. Unload the model using ML_MODEL_UNLOAD:

    mysql> CALL sys.ML_MODEL_UNLOAD(@iris_model);

    To avoid consuming too much space, it is good practice to unload a model when you are finished using it.
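
    If you want to rerun the tutorial from a clean state, the output tables created in steps 6 and 7 can be removed with standard DROP statements. This is an optional sketch; dropping the ml_data schema is left commented out because it also removes the training, test, and validation data.

    ```sql
    -- Remove the tables written by ML_PREDICT_TABLE and ML_EXPLAIN_TABLE.
    DROP TABLE IF EXISTS ml_data.iris_predictions;
    DROP TABLE IF EXISTS ml_data.iris_explanations;

    -- Uncomment to remove the tutorial schema and all of its tables.
    -- DROP SCHEMA IF EXISTS ml_data;
    ```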