HeatWave ML 适用于标记和未标记的数据。标记数据用于训练和评分机器学习模型。生成预测和解释时需要未标记的数据。
标记数据
标记数据具有特征列和目标列( 标签),如下图所示:
特征列包含用于训练机器学习模型的输入变量。目标列包含 基本真实值,或者换句话说,包含正确答案。训练机器学习模型需要具有真实值的标记数据集。在本指南的上下文中,用于训练机器学习模型的标记数据集称为训练数据集。
带有真实值的标记数据集也用于对模型进行评分(计算其准确性和可靠性)。此数据集应具有与训练数据集相同的列,但具有不同的数据集。在本指南的上下文中,用于对模型评分的标记数据集称为验证数据集。
未标记数据
未标记数据有特征列但没有目标列(没有答案),如下图所示:
需要未标记的数据来生成预测和解释。它必须具有与训练数据集完全相同的特征列,但没有目标列。在本指南的上下文中,用于预测和解释的未标记数据称为测试数据集。测试数据以标记数据开始,但为了试验机器学习模型而删除了标签。
您最终将与您的模型一起使用以进行预测的“看不见的数据”也是未标记的数据。与 测试数据集一样,未见过的数据必须具有与训练数据集完全相同的特征列,但没有目标列。
有关训练、验证和测试数据集表的示例及其结构,请参阅 示例数据和 第 6.3 节“Iris 数据集机器学习快速入门”。
一般数据要求
HeatWave ML 数据的一般要求包括以下内容:
-
每个数据集必须驻留在 MySQL 数据库系统的单个表中。HeatWave ML 例程,例如
ML_TRAIN
、ML_PREDICT_TABLE
和ML_EXPLAIN_TABLE
在单个表上运行。有关将数据加载到 MySQL 数据库系统的信息,请参阅 导入和导出数据库。
用于 HeatWave ML 的表不得超过 10 GB、1 亿行或 900 列。
表列必须使用受支持的数据类型。有关支持的数据类型和如何处理不支持的类型的建议,请参阅第 3.11 节,“支持的数据类型”。
NaN(非数字)值无法被 MySQL 识别,应替换为
NULL
.分类模型的训练数据集中的目标列必须至少有两个不同的值,并且每个不同的值应至少出现在五行中。对于回归模型,只允许使用数字目标列。
该
ML_TRAIN
例程将忽略其值缺失超过 20% 的列以及每行中具有相同值的列。数值列中的缺失值替换为列的平均值,标准化为均值为 0,标准差为 1。分类列中的缺失值替换为最常见的值,以及单热或有序编码用于将分类值转换为数值。MySQL 数据库中存在的输入数据不会被
ML_TRAIN
.
本指南中的示例使用人口普查收入数据集。
Dua, D. 和 Graff, C.(2019 年)。UCI 机器学习库 [ http://archive.ics.uci.edu/ml ]。加州欧文市:加州大学信息学院。
人口普查收入数据集示例演示了classification
训练和推理。HeatWave ML 还支持
regression
针对适合该目的的数据集进行训练和推理。该
参数定义机器学习模型是否针对
或
。
ML_TRAIN
task
classification
regression
要复制本指南中的示例,请执行以下步骤来创建所需的架构和表。需要 Python 3 和 MySQL Shell。
-
通过执行以下语句在 MySQL 数据库系统上创建以下模式和表:
CREATE SCHEMA heatwaveml_bench; USE heatwaveml_bench; CREATE TABLE census_train ( age INT, workclass VARCHAR(255), fnlwgt INT, education VARCHAR(255), `education-num` INT, `marital-status` VARCHAR(255), occupation VARCHAR(255), relationship VARCHAR(255), race VARCHAR(255), sex VARCHAR(255), `capital-gain` INT, `capital-loss` INT, `hours-per-week` INT, `native-country` VARCHAR(255), revenue VARCHAR(255)); CREATE TABLE `census_test` LIKE `census_train`;
在https://github.com/oracle-samples/heatwave-ml 导航到HeatWave ML 代码的性能基准GitHub 存储库 。
-
按照
README.md
说明创建census_train.csv
和census_test.csv
数据文件。总之,说明是:-
安装所需的 Python 包:
pip install pandas==1.2.3 numpy==1.22.2 unlzw3==0.2.1 sklearn==1.0.2
下载或克隆存储库,其中包括人口普查源数据和预处理脚本。
-
运行
preprocess.py
脚本以创建census_train.csv
和census_test.csv
数据文件。python3 heatwave-ml/preprocess.py --benchmark census
笔记不要按照
README.md
文件中的说明运行基准测试。基准测试脚本在处理结束时删除模式和数据。 -
-
--mysql
启动带有打开选项的 MySQL ShellClassicSession
,这在使用 并行表导入实用程序时是必需的。mysqlsh --mysql Username@IPAddressOfMySQLDBSystemEndpoint
-
.csv
使用以下命令 将文件中的数据加载到 MySQL 数据库系统中:MySQL>JS> util.importTable("census_train.csv",{table: "census_train", dialect: "csv-unix", skipRows:1}) MySQL>JS> util.importTable("census_test.csv",{table: "census_test", dialect: "csv-unix", skipRows:1})
-
创建验证表:
CREATE TABLE `census_validate` LIKE `census_test`; INSERT INTO `census_validate` SELECT * FROM `census_test`;
-
修改
census_test
表以删除目标`revenue`
列:ALTER TABLE `census_test` DROP COLUMN `revenue`;
有关可与 HeatWave ML 一起使用的其他示例数据集,请参阅HeatWave ML 性能基准代码GitHub 存储库, 网址为https://github.com/oracle-samples/heatwave-ml。