3.3 准备数据

HeatWave ML 适用于标记和未标记的数据。标记数据用于训练和评分机器学习模型。生成预测和解释时需要未标记的数据。

标记数据

标记数据具有特征列和目标列( 标签),如下图所示:

图 3.2 标签数据

显示标记数据集表的图像。

特征列包含用于训练机器学习模型的输入变量。目标列包含 基本真实值,或者换句话说,包含正确答案。训练机器学习模型需要具有真实值的标记数据集。在本指南的上下文中,用于训练机器学习模型的标记数据集称为训练数据集

带有真实值的标记数据集也用于对模型进行评分(计算其准确性和可靠性)。此数据集应具有与训练数据集相同的列,但具有不同的数据集。在本指南的上下文中,用于对模型评分的标记数据集称为验证数据集

未标记数据

未标记数据有特征列但没有目标列(没有答案),如下图所示:

图 3.3 未标记数据

显示未标记数据集表的图像。

需要未标记的数据来生成预测和解释。它必须具有与训练数据集完全相同的特征列,但没有目标列。在本指南的上下文中,用于预测和解释的未标记数据称为测试数据集。测试数据以标记数据开始,但为了试验机器学习模型而删除了标签。

您最终将与您的模型一起使用以进行预测的看不见的数据也是未标记的数据。与 测试数据集一样,未见过的数据必须具有与训练数据集完全相同的特征列,但没有目标列。

有关训练、验证和测试数据集表的示例及其结构,请参阅 示例数据第 6.3 节“Iris 数据集机器学习快速入门”

一般数据要求

HeatWave ML 数据的一般要求包括以下内容:

  • 每个数据集必须驻留在 MySQL 数据库系统的单个表中。HeatWave ML 例程,例如 ML_TRAINML_PREDICT_TABLEML_EXPLAIN_TABLE 在单个表上运行。

    有关将数据加载到 MySQL 数据库系统的信息,请参阅 导入和导出数据库

  • 用于 HeatWave ML 的表不得超过 10 GB、1 亿行或 900 列。

  • 表列必须使用受支持的数据类型。有关支持的数据类型和如何处理不支持的类型的建议,请参阅第 3.11 节,“支持的数据类型”

  • NaN(非数字)值无法被 MySQL 识别,应替换为NULL.

  • 分类模型的训练数据集中的目标列必须至少有两个不同的值,并且每个不同的值应至少出现在五行中。对于回归模型,只允许使用数字目标列。

笔记

ML_TRAIN 例程将忽略其值缺失超过 20% 的列以及每行中具有相同值的列。数值列中的缺失值替换为列的平均值,标准化为均值为 0,标准差为 1。分类列中的缺失值替换为最常见的值,以及单热或有序编码用于将分类值转换为数值。MySQL 数据库中存在的输入数据不会被 ML_TRAIN.

示例数据

本指南中的示例使用人口普查收入数据集

Dua, D. 和 Graff, C.(2019 年)。UCI 机器学习库 [ http://archive.ics.uci.edu/ml ]。加州欧文市:加州大学信息学院。

笔记

人口普查收入数据集示例演示了classification训练和推理。HeatWave ML 还支持 regression针对适合该目的的数据集进行训练和推理。该 参数定义机器学习模型是否针对 或 。 ML_TRAIN taskclassificationregression

要复制本指南中的示例,请执行以下步骤来创建所需的架构和表。需要 Python 3 和 MySQL Shell。

  1. 通过执行以下语句在 MySQL 数据库系统上创建以下模式和表:

    CREATE SCHEMA heatwaveml_bench;
    
    USE heatwaveml_bench;
    
    CREATE TABLE census_train ( 
      age INT, workclass VARCHAR(255), 
      fnlwgt INT, education VARCHAR(255), 
      `education-num` INT, 
      `marital-status` VARCHAR(255), 
      occupation VARCHAR(255), 
      relationship VARCHAR(255), 
      race VARCHAR(255), 
      sex VARCHAR(255), 
      `capital-gain` INT, 
      `capital-loss` INT, 
      `hours-per-week` INT, 
      `native-country` VARCHAR(255), 
      revenue VARCHAR(255));
      
    CREATE TABLE `census_test` LIKE `census_train`;
  2. 在https://github.com/oracle-samples/heatwave-ml 导航到HeatWave ML 代码的性能基准GitHub 存储库 。

  3. 按照README.md说明创建census_train.csvcensus_test.csv数据文件。总之,说明是:

    1. 安装所需的 Python 包:

      pip install pandas==1.2.3 numpy==1.22.2 unlzw3==0.2.1 sklearn==1.0.2
    2. 下载或克隆存储库,其中包括人口普查源数据和预处理脚本。

    3. 运行preprocess.py脚本以创建census_train.csvcensus_test.csv数据文件。

      python3 heatwave-ml/preprocess.py --benchmark census
    笔记

    不要按照 README.md文件中的说明运行基准测试。基准测试脚本在处理结束时删除模式和数据。

  4. --mysql 启动带有打开选项的 MySQL Shell ClassicSession,这在使用 并行表导入实用程序时是必需的。

    mysqlsh --mysql Username@IPAddressOfMySQLDBSystemEndpoint
  5. .csv使用以下命令 将文件中的数据加载到 MySQL 数据库系统中:

    MySQL>JS> util.importTable("census_train.csv",{table: "census_train", 
    dialect: "csv-unix", skipRows:1})
    
    MySQL>JS> util.importTable("census_test.csv",{table: "census_test", 
    dialect: "csv-unix", skipRows:1})
  6. 创建验证表:

    CREATE TABLE `census_validate` LIKE `census_test`;
    
    INSERT INTO `census_validate` SELECT * FROM `census_test`;
  7. 修改census_test表以删除目标`revenue`列:

    ALTER TABLE `census_test` DROP COLUMN `revenue`;

其他示例数据集

有关可与 HeatWave ML 一起使用的其他示例数据集,请参阅HeatWave ML 性能基准代码GitHub 存储库, 网址为https://github.com/oracle-samples/heatwave-ml