Dataset
MindSpore Transformers currently supports multiple types of dataset loading methods, covering common open-source and custom scenarios. Specifically, it includes:
- Megatron Datasets: Supports loading datasets in the Megatron-LM format, suitable for large-scale language model pre-training tasks.
- HuggingFace Datasets: Compatible with the HuggingFace `datasets` library, making it convenient to access a wide range of public data resources from the community.
- MindRecord Datasets: MindRecord is an efficient data storage and reading module provided by MindSpore. This module offers various methods to help users convert different public datasets into the MindRecord format, as well as tools for reading, writing, and retrieving data from MindRecord files.
Megatron Dataset
Megatron dataset is an efficient data format designed for large-scale distributed language model pre-training scenarios, widely used within the Megatron-LM framework. These datasets are typically preprocessed and serialized into binary formats (such as `.bin` and `.idx` files), accompanied by specific indexing mechanisms to enable efficient parallel loading and data partitioning in distributed cluster environments.
The following sections explain how to generate `.bin` and `.idx` files, as well as how to use Megatron datasets in training tasks.
Data Preprocessing
MindSpore Transformers provides a data preprocessing script, preprocess_indexed_dataset.py, which converts raw text data in `json` format into `.bin` and `.idx` files.
If the raw text data is not in `json` format, users need to preprocess and convert it into the appropriate format themselves.
Below is an example of a `json` format file:
{"src": "www.nvidia.com", "text": "The quick brown fox", "type": "Eng", "id": "0", "title": "First Part"}
{"src": "The Internet", "text": "jumps over the lazy dog", "type": "Eng", "id": "42", "title": "Second Part"}
...
The descriptions for each data field are as follows:
| Field Name | Description | Required |
|---|---|---|
| text | Raw text data | Yes |
| id | Unique identifier (in order) | No |
| src | Data source | No |
| type | Language type | No |
| title | Data title | No |
The following example demonstrates how to convert the wikitext-103 dataset into the Megatron dataset format:

1. Download the wikitext-103 dataset: Link

2. Generate a `json` format data file.

    The original text of the wikitext-103 dataset looks like this:

    = Valkyria Chronicles III =
    Valkyria Chronicles III is a tactical role-playing game developed by Sega for the PlayStation Portable. The game was released in Japan on January 27, 2011.
    = Gameplay =
    The game is similar to its predecessors in terms of gameplay...

    You need to preprocess the original text into the following format and save it as a `json` file (a minimal conversion sketch is provided after the command examples below):

    {"id": 0, "text": "Valkyria Chronicles III is a tactical role-playing game..."}
    {"id": 1, "text": "The game is similar to its predecessors in terms of gameplay..."}
    ...

3. Download the model's vocabulary file.

    Since different models use different vocabulary files, you need to download the vocabulary file that matches the training model. Taking the Llama3 model as an example, download the tokenizer.model for data preprocessing.

4. Generate the `.bin` and `.idx` data files.

    Run the data preprocessing script preprocess_indexed_dataset.py to convert the raw text data into the corresponding token IDs using the model's tokenizer.
The script accepts the following parameters:
| Parameter Name | Description |
|---|---|
| input | Path to the `json` format file |
| output-prefix | Prefix for the `.bin` and `.idx` data files |
| tokenizer-type | Type of tokenizer used by the model |
| vocab-file | Path to the model's tokenizer file (`tokenizer.model`/`vocab.json`) |
| merges-file | Path to the model's tokenizer merges file (`merge.txt`) |
| tokenizer-file | Path to the model's tokenizer file (`tokenizer.json`) |
| add_bos_token | Whether to add a `bos_token` (beginning-of-sequence token) to the vocabulary |
| add_eos_token | Whether to add an `eos_token` (end-of-sequence token) to the vocabulary |
| eos_token | The string representing the `eos_token`; defaults to '' |
| append-eod | Whether to append an `eos_token` to the end of each document |
| tokenizer-dir | Directory of the HuggingFace tokenizer. Takes effect only when `tokenizer-type='HuggingFaceTokenizer'` |
| trust-remote-code | Whether to allow custom models defined on the Hub. Takes effect only when `tokenizer-type='HuggingFaceTokenizer'` |
| register_path | Code directory of the external tokenizer. Takes effect only when `tokenizer-type='AutoRegister'` |
| auto_register | Import path of the external tokenizer. Takes effect only when `tokenizer-type='AutoRegister'` |

The optional values of `tokenizer-type` are 'HuggingFaceTokenizer' and 'AutoRegister'. When it is set to 'HuggingFaceTokenizer', the `AutoTokenizer` class from the `transformers` library instantiates the tokenizer from a local HuggingFace repository. When it is set to 'AutoRegister', the external tokenizer class specified by `register_path` and `auto_register` is applied.

Take LlamaTokenizerFast with the vocabulary file from the DeepSeek-V3 repository as an example. If there is no corresponding local repository, the configuration file (tokenizer_config.json) and vocabulary file (tokenizer.json) need to be downloaded to a local path, for example /path/to/huggingface/tokenizer. Execute the following command to preprocess the dataset:
python toolkit/data_preprocess/megatron/preprocess_indexed_dataset.py \
  --input /path/data.json \
  --output-prefix /path/megatron_data \
  --tokenizer-type HuggingFaceTokenizer \
  --tokenizer-dir /path/to/huggingface/tokenizer
Take the external tokenizer class Llama3Tokenizer as an example: make sure the local MindSpore Transformers repository contains 'research/llama3_1/llama3_1_tokenizer.py', and execute the following command to preprocess the dataset:
python toolkit/data_preprocess/megatron/preprocess_indexed_dataset.py \
  --input /path/data.json \
  --output-prefix /path/megatron_data \
  --tokenizer-type AutoRegister \
  --vocab-file /path/tokenizer.model \
  --register_path research/llama3_1 \
  --auto_register llama3_1_tokenizer.Llama3Tokenizer
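As referenced in step 2 above, below is a minimal sketch of converting the raw wikitext-103 text into the expected `json` format (one JSON object per line). The paragraph-splitting rule and the file paths are simplifying assumptions for illustration, not part of MindSpore Transformers:

import json

# Placeholder paths; adjust them to the raw wikitext-103 file and the desired output.
raw_path, json_path = "/path/wiki.train.tokens", "/path/data.json"

with open(raw_path, encoding="utf-8") as fin, open(json_path, "w", encoding="utf-8") as fout:
    doc_id, buffer = 0, []
    for line in fin:
        line = line.strip()
        if line.startswith("=") and buffer:  # a section header closes the previous document
            fout.write(json.dumps({"id": doc_id, "text": " ".join(buffer)}) + "\n")
            doc_id, buffer = doc_id + 1, []
        elif line and not line.startswith("="):
            buffer.append(line)
    if buffer:  # flush the last document
        fout.write(json.dumps({"id": doc_id, "text": " ".join(buffer)}) + "\n")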
Model Pre-training
MindSpore Transformers recommends using Megatron datasets for model pre-training. Based on the Data Preprocessing steps, you can generate the required pre-training dataset. The following explains how to configure and use Megatron datasets in the configuration file.
1. Prepare the `parallel_speed_up.json` file.

    The Megatron dataset relies on the `dataset_broadcast_opt_level` feature for data broadcasting. For more details, refer to the documentation. Therefore, you need to create a `parallel_speed_up.json` file with the following content:

    {
        "dataset_broadcast_opt_level": 3
    }

    At the same time, add the following fields to the model configuration file:

    context:
      ascend_config:
        parallel_speed_up_json_path: "/path/to/parallel_speed_up.json"

2. Modify the model configuration file.

    To use the Megatron dataset in model pre-training tasks, you mainly need to modify the `train_dataset` section in the configuration file:

    train_dataset: &train_dataset
      data_loader:
        type: BlendedMegatronDatasetDataLoader
        datasets_type: "GPTDataset"
        sizes:
          - 1000  # Number of training dataset samples
          - 0     # Number of testing dataset samples (currently unsupported)
          - 0     # Number of evaluation dataset samples (currently unsupported)
        config:   # GPTDataset configuration options
          seed: 1234                         # Random seed for data sampling
          split: "1, 0, 0"                   # Ratio of training, testing, and evaluation datasets (currently unsupported)
          seq_length: 8192                   # Sequence length of data returned by the dataset
          eod_mask_loss: True                # Whether to compute loss at end-of-document (EOD) tokens
          reset_position_ids: True           # Whether to reset position_ids at EOD tokens
          create_attention_mask: True        # Whether to return attention_mask
          reset_attention_mask: True         # Whether to reset attention_mask at EOD tokens, returning a staircase-shaped mask
          create_compressed_eod_mask: False  # Whether to return a compressed attention_mask
          eod_pad_length: 128                # Length of the compressed attention_mask
          eod: 0                             # Token ID of the EOD token in the dataset
          pad: 1                             # Token ID of the pad token in the dataset

          data_path:                  # Sampling ratio and paths for Megatron datasets
            - '0.3'                   # Ratio of dataset1
            - "/path/megatron_data1"  # Path to bin file of dataset1 excluding the .bin suffix
            - '0.7'                   # Ratio of dataset2
            - "/path/megatron_data2"  # Path to bin file of dataset2 excluding the .bin suffix
      input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
      construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]

    parallel:
      full_batch: False
      dataset_strategy: [[*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1, 1, 1]]  # *dp means same value as data_parallel

    model_config:
      input_sliced_sig: True
Below are the descriptions of each `GPTDataset` configuration option:

| Parameter Name | Description |
|---|---|
| seed | Random seed for dataset sampling. Megatron datasets use this value to randomly sample and concatenate samples. Default: `1234` |
| seq_length | Sequence length of the data returned by the dataset. Should be consistent with the sequence length of the training model |
| eod_mask_loss | Whether to compute loss at the end-of-document (EOD) token. Default: `False` |
| create_attention_mask | Whether to return an attention_mask. Default: `True` |
| reset_attention_mask | Whether to reset the attention_mask at EOD tokens, returning a staircase-shaped attention_mask. Effective only if `create_attention_mask=True`. Default: `False` |
| create_compressed_eod_mask | Whether to return a compressed attention_mask. Has higher priority than `create_attention_mask`. Default: `False` |
| eod_pad_length | Length of the compressed attention_mask. Effective only if `create_compressed_eod_mask=True`. Default: `128` |
| eod | Token ID of the EOD token in the dataset |
| pad | Token ID of the pad token in the dataset |
| data_path | A list in which every two consecutive elements (a number and a string) define one dataset: the sampling ratio of that dataset and the path to its bin file without the `.bin` suffix. The sum of all ratios should equal 1 |

In addition, the Megatron dataset also depends on configurations such as `input_columns`, `construct_args_key`, and `full_batch`. For more details, refer to the configuration file documentation. Here, we only explain how to configure them in different scenarios:
- When `create_compressed_eod_mask=True`:

  train_dataset: &train_dataset
    input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "actual_seq_len"]
    construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "actual_seq_len"]

  parallel:
    full_batch: False
    dataset_strategy: [[*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1]]  # *dp means same value as data_parallel

- When `create_compressed_eod_mask=False` and `create_attention_mask=True`:

  train_dataset: &train_dataset
    input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
    construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]

  parallel:
    full_batch: False
    dataset_strategy: [[*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1, 1, 1]]  # *dp means same value as data_parallel

- When `create_compressed_eod_mask=False` and `create_attention_mask=False`:

  train_dataset: &train_dataset
    input_columns: ["input_ids", "labels", "loss_mask", "position_ids"]
    construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids"]

  parallel:
    full_batch: False
    dataset_strategy: [[*dp, 1], [*dp, 1], [*dp, 1], [*dp, 1]]  # *dp means same value as data_parallel
3. Start model pre-training.

    After modifying the dataset and parallel-related configurations in the model configuration file, you can refer to the model documentation to launch the model pre-training task. Here, we take the Llama3_1 model documentation as an example.
Hugging Face Dataset
The HuggingFace Dataset (HF Dataset) module is integrated with the HuggingFace community, providing efficient and flexible HF dataset loading and processing. Main features include:
- Diverse Data Loading: Supports various formats and loading methods from the Hugging Face `datasets` library, easily adapting to different sources and structures.
- Rich Data Processing Interfaces: Compatible with multiple processing methods from the `datasets` library (such as `sort`, `flatten`, `shuffle`, etc.), meeting common preprocessing needs.
- Extensible Data Operations: Supports user-defined dataset processing logic and provides efficient packing functionality for large-scale training optimization.
To use HuggingFace datasets in MindSpore Transformers, you need to understand the basic functionalities of the `datasets` third-party library, such as dataset loading and processing. For more details, please refer to this link. If the Python version is less than 3.10, you need to install a version of aiohttp below 3.8.1.
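For reference, a minimal sketch of the underlying `datasets` workflow that these features build on, loading a local JSON file and applying simple transforms (the file path and the `instruction` field are assumptions based on the Alpaca-style data used later):

from datasets import load_dataset

# Load a local JSON Lines file as the training split (placeholder path).
ds = load_dataset("json", data_files="/path/alpaca-gpt4-data.json", split="train")
print(ds.column_names)

# Apply a simple per-sample transform and shuffle with the native datasets APIs.
ds = ds.map(lambda sample: {"instruction_length": len(sample["instruction"])})
ds = ds.shuffle(seed=1234)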
Configuration
To use HF dataset functionality in model training, modify the `data_loader` configuration:
train_dataset: &train_dataset
  input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
  construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
  data_loader:
    type: HFDataLoader
    # datasets load arguments
    load_func: 'load_dataset'
    path: "json"
    data_files: "/path/alpaca-gpt4-data.json"
    split: "train"
    # MindSpore Transformers dataset arguments
    create_attention_mask: True
    create_compressed_eod_mask: False
    compressed_eod_mask_length: 128
    use_broadcast_data: True
    shuffle: False
    # dataset process arguments
    handler:
      - type: AlpacaInstructDataHandler
        seq_length: 4096
        padding: False
        tokenizer:
          pretrained_model_dir: '/path/qwen3'
          trust_remote_code: True
          padding_side: 'right'
      - type: PackingHandler
        seq_length: 4096
        pack_strategy: 'pack'
  num_parallel_workers: 8
  python_multiprocessing: False
  drop_remainder: True
  numa_enable: False
  prefetch_size: 1
  seed: 1234
All examples use the `seq_length`, `tokenizer`, and other settings of the `qwen3` model.
`data_loader` parameter descriptions:

| Parameter | Description | Type |
|---|---|---|
| type | Fixed as `HFDataLoader` | str |
| load_func | Specifies the dataset loading interface; options are `load_dataset` and `load_from_disk` | str |
| create_attention_mask | Whether to return an attention mask during dataset iteration | bool |
| create_compressed_eod_mask | Whether to return a compressed one-dimensional attention mask during iteration | bool |
| compressed_eod_mask_length | Length of the compressed attention mask, usually the maximum number of EOD tokens in the samples | int |
| use_broadcast_data | Whether to enable data broadcasting | bool |
| shuffle | Whether to randomly sample the dataset | bool |
| handler | Data preprocessing operations. For details, refer to the Dataset Processing section | list |
Dataset Loading
The dataset loading functionality is mainly controlled by the `load_func` parameter. `HFDataLoader` passes all parameters (except those defined in Configuration) as input arguments to the dataset loading interface. The detailed usage is as follows:
- Using the `datasets.load_dataset` interface to load datasets:

    In the dataset configuration, set `load_func: 'load_dataset'` and configure the following parameters:

    - path (str): Path or name of the dataset directory.

        If `path` is a local directory, the dataset is loaded from the supported files (csv, json, parquet, etc.) in that directory. Example: `'/path/json/'`.

        If `path` is the name of a dataset builder and `data_files` or `data_dir` is specified (available builders include `"json"`, `"csv"`, `"parquet"`, `"arrow"`, etc.), the dataset is loaded from the files in `data_files` or `data_dir`.

    - data_dir (str, optional): When `path` is set to the name of a dataset builder, this specifies the dataset directory path.
    - data_files (str, optional): When `path` is set to the name of a dataset builder, this specifies the dataset file path(s). It can be a single file or a list of file paths.
    - split (str): The data split to load. If set to `None`, a dictionary containing all splits is returned (typically `datasets.Split.TRAIN` and `datasets.Split.TEST`). If specified, the corresponding split is returned as a `Dataset` instance.

- Using the `datasets.load_from_disk` interface to load datasets:

    In the dataset configuration, set `load_func: 'load_from_disk'` and configure the following parameter:

    - dataset_path (str): Path to the dataset directory. This interface is typically used to load datasets that have been preprocessed offline or saved using `datasets.save_to_disk`.
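For context, a minimal sketch of how a dataset saved with `datasets.save_to_disk` can later be reloaded with `datasets.load_from_disk`; the paths are placeholders:

from datasets import load_dataset, load_from_disk

# Load and preprocess once, then persist the result to disk (placeholder paths).
ds = load_dataset("json", data_files="/path/alpaca-gpt4-data.json", split="train")
ds.save_to_disk("/path/processed_dataset")

# Later, reload the saved dataset directly; this is what
# load_func: 'load_from_disk' together with dataset_path points at.
reloaded = load_from_disk("/path/processed_dataset")
print(reloaded)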
Dataset Processing
`HFDataLoader` supports both native `datasets` processing and user-defined operations, mainly via the `handler` mechanism, which executes the configured preprocessing steps in order.
Native Processing
To rename dataset columns, remove columns, or randomly sample the dataset, you can configure as follows:
handler:
  - type: 'rename_column'
    original_column_name: 'col1'
    new_column_name: 'col2'
  - type: 'remove_columns'
    column_names: 'col2'
  - type: 'shuffle'
    seed: 42
- rename_column (rename a column): renames `col1` to `col2`.
- remove_columns (remove columns): removes `col2`.
- shuffle (shuffle the dataset): shuffles the dataset with seed 42.
For other native dataset processing operations, please refer to the datasets process documentation.
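For reference, the handlers above correspond roughly to the following native `datasets` calls (a sketch; the toy columns are assumptions for illustration):

from datasets import Dataset

# Toy two-column dataset; the column names are placeholders.
ds = Dataset.from_dict({"col1": [1, 2, 3], "other": ["a", "b", "c"]})

ds = ds.rename_column("col1", "col2")  # rename_column handler
ds = ds.remove_columns(["col2"])       # remove_columns handler
ds = ds.shuffle(seed=42)               # shuffle handler
print(ds.column_names)                 # ['other']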
Custom Processing
To use custom preprocessing, implement your own handler module. See AlpacaInstructDataHandler.
Custom handlers support both `Class` and `Method` forms:
If using a `Class`:
Implement a class with a call function:
class CustomHandler:
    def __init__(self, seed):
        self.seed = seed

    def __call__(self, dataset):
        dataset = dataset.shuffle(seed=self.seed)
        return dataset
The `CustomHandler` above implements random sampling of the dataset. To implement other functions, modify the data preprocessing operations and return the processed dataset. MindSpore Transformers provides BaseInstructDataHandler with a built-in tokenizer configuration. If you need to use a tokenizer, you can inherit from the `BaseInstructDataHandler` class.

Add to __init__.py:
from .custom_handler import CustomHandler
Use in config:
handler:
  - type: CustomHandler
    seed: 42
If using a `Method`:
Implement a function with dataset as input:
def custom_process(dataset, seed):
    dataset = dataset.shuffle(seed)
    return dataset
Add to __init__.py:
from .custom_handler import custom_process
Use in config:
handler:
  - type: custom_process
    seed: 42
Practical Application
Below, we use the `qwen3` model and the `alpaca` dataset as examples to demonstrate how to fine-tune with an HF dataset. The `AlpacaInstructDataHandler` is used for online data processing. Its main parameters are described as follows:
- seq_length: Maximum length when encoding text into token IDs via the tokenizer; usually matches the model's training sequence length.
- padding: Whether to pad token IDs to the maximum length during encoding.
- tokenizer: `pretrained_model_dir` is the folder containing the model vocabulary and weights downloaded from HF. `trust_remote_code` is usually set to `True`, and `padding_side: 'right'` indicates that padding is applied on the right side of the token IDs.
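As a rough illustration (not the handler's actual implementation), these tokenizer settings correspond to `transformers` usage along the following lines; the path and sample text are placeholders:

from transformers import AutoTokenizer

# Load the tokenizer from a local HuggingFace-style model directory (placeholder path).
tokenizer = AutoTokenizer.from_pretrained(
    "/path/qwen3", trust_remote_code=True, padding_side="right"
)

# Encode one sample; padding/truncation mirror the handler's seq_length and padding options.
encoded = tokenizer(
    "Give three tips for staying healthy.",
    max_length=4096,
    padding="max_length",  # corresponds to padding: True
    truncation=True,
)
print(len(encoded["input_ids"]))  # 4096 when padded to max_length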
Alpaca Dataset Fine-tuning
For `qwen3` model fine-tuning, modify the training config:
train_dataset: &train_dataset
  input_columns: ["input_ids", "labels"]
  construct_args_key: ["input_ids", "labels"]
  data_loader:
    type: HFDataLoader
    # datasets load arguments
    load_func: 'load_dataset'
    path: 'json'
    data_files: '/path/alpaca-gpt4-data.json'
    # MindSpore Transformers dataset arguments
    use_broadcast_data: True
    shuffle: False
    # dataset process arguments
    handler:
      - type: AlpacaInstructDataHandler
        seq_length: 4096
        padding: True
        tokenizer:
          pretrained_model_dir: '/path/qwen3'  # qwen3 repo dir
          trust_remote_code: True
          padding_side: 'right'
  num_parallel_workers: 8
  python_multiprocessing: False
  drop_remainder: True
  numa_enable: False
  prefetch_size: 1
  seed: 1234

context:
  ascend_config:
    parallel_speed_up_json_path: "configs/qwen3/parallel_speed_up.json"

parallel_config:
  data_parallel: &dp 2

parallel:
  full_batch: False
  dataset_strategy: [
    [*dp, 1],
    [*dp, 1]
  ]  # *dp = data_parallel
See the Megatron Dataset section for details on `parallel_speed_up_json_path`, `dataset_strategy`, etc.
After modifying the configuration file, refer to the `qwen3` model documentation to start the fine-tuning task.
Alpaca Dataset Packing Fine-tuning
MindSpore Transformers implements dataset packing functionality, which is mainly used in large-scale model training tasks to concatenate multiple short sequences into fixed-length long sequences, thereby improving training efficiency. It currently supports two strategies, which can be configured through `pack_strategy`:
- pack: Concatenates multiple samples into a fixed-length sequence. When a sample to be concatenated would exceed the maximum length `seq_length`, that sample is placed into the next concatenated sequence.
- truncate: Concatenates multiple samples into a fixed-length sequence. When a sample to be concatenated would exceed the maximum length `seq_length`, the sample is truncated and the remaining part is placed into the next concatenated sequence.
This functionality is implemented through the `PackingHandler` class. The final output contains only three fields: `input_ids`, `labels`, and `actual_seq_len`.
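To make the two strategies concrete, below is a simplified, standalone sketch of the packing logic. It is not the actual `PackingHandler` implementation (it omits padding and the `actual_seq_len` bookkeeping), and the toy token IDs and `seq_length` of 8 are made up for illustration:

def pack_samples(samples, seq_length, strategy="pack"):
    """Greedily concatenate token-ID lists into sequences of at most seq_length tokens."""
    packed, current = [], []
    for sample in samples:
        if len(current) + len(sample) <= seq_length:
            current += sample                        # the sample still fits in the current sequence
        elif strategy == "pack":
            packed.append(current)                   # close the current sequence
            current = list(sample)                   # the whole sample starts the next sequence
        else:  # "truncate"
            space = seq_length - len(current)
            packed.append(current + sample[:space])  # fill up the current sequence
            current = list(sample[space:])           # the remainder carries over
    if current:
        packed.append(current)
    return packed

samples = [[1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(pack_samples(samples, seq_length=8, strategy="pack"))      # [[1, 2, 3, 4, 5, 6, 7], [8, 9]]
print(pack_samples(samples, seq_length=8, strategy="truncate"))  # [[1, 2, 3, 4, 5, 6, 7, 8], [9]]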
For packing fine-tuning with `qwen3`, modify the training config:
train_dataset: &train_dataset
  input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
  construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
  data_loader:
    type: HFDataLoader
    # datasets load arguments
    load_func: 'load_dataset'
    path: 'json'
    data_files: '/path/alpaca-gpt4-data.json'
    # MindSpore Transformers dataset arguments
    use_broadcast_data: True
    shuffle: False
    # dataset process arguments
    handler:
      - type: AlpacaInstructDataHandler
        seq_length: 4096
        padding: False
        tokenizer:
          pretrained_model_dir: '/path/qwen3'  # qwen3 repo dir
          trust_remote_code: True
          padding_side: 'right'
      - type: PackingHandler
        seq_length: 4096
        pack_strategy: 'pack'
  num_parallel_workers: 8
  python_multiprocessing: False
  drop_remainder: True
  numa_enable: False
  prefetch_size: 1
  seed: 1234

context:
  ascend_config:
    parallel_speed_up_json_path: "configs/qwen3/parallel_speed_up.json"

parallel_config:
  data_parallel: &dp 2

parallel:
  full_batch: False
  dataset_strategy: [
    [*dp, 1],
    [*dp, 1],
    [*dp, 1],
    [*dp, 1],
    [*dp, 1, 1, 1]
  ]  # *dp = data_parallel
After modifying the config, refer to the `qwen3` model documentation to start fine-tuning.
Offline Processing for Alpaca Data Fine-tuning
`HFDataLoader` supports offline processing and saving of HF datasets; the processed data can then be loaded directly for training.
1. Modify the `qwen3` training config:

    train_dataset: &train_dataset
      data_loader:
        type: HFDataLoader
        # datasets load arguments
        load_func: 'load_dataset'
        path: 'json'
        data_files: '/path/alpaca-gpt4-data.json'
        # dataset process arguments
        handler:
          - type: AlpacaInstructDataHandler
            seq_length: 4096
            padding: False
            tokenizer:
              pretrained_model_dir: '/path/qwen3'  # qwen3 repo dir
              trust_remote_code: True
              padding_side: 'right'
          - type: PackingHandler
            seq_length: 4096
            pack_strategy: 'pack'
2. Run the preprocessing script:

    python toolkit/data_preprocess/huggingface/datasets_preprocess.py \
      --config configs/qwen3/pretrain_qwen3_32b_4k.yaml \
      --save_path processed_dataset/
3. Modify the config to load the processed dataset:

    train_dataset: &train_dataset
      input_columns: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
      construct_args_key: ["input_ids", "labels", "loss_mask", "position_ids", "attention_mask"]
      data_loader:
        type: HFDataLoader
        # datasets load arguments
        load_func: 'load_from_disk'
        dataset_path: '/path/processed_dataset'
        # MindSpore Transformers dataset arguments
        create_attention_mask: True
        use_broadcast_data: True
        shuffle: False
      num_parallel_workers: 8
      python_multiprocessing: False
      drop_remainder: True
      numa_enable: False
      prefetch_size: 1
      seed: 1234

    context:
      ascend_config:
        parallel_speed_up_json_path: "configs/qwen3/parallel_speed_up.json"

    parallel_config:
      data_parallel: &dp 2

    parallel:
      full_batch: False
      dataset_strategy: [
        [*dp, 1],
        [*dp, 1],
        [*dp, 1],
        [*dp, 1],
        [*dp, 1, 1, 1]
      ]  # *dp = data_parallel
After modifying the configuration file, refer to the `qwen3` model documentation to initiate a fine-tuning task that loads the offline-processed data.
MindRecord Dataset
MindRecord is an efficient data storage and reading module provided by MindSpore. It reduces disk IO and network IO overhead, resulting in a better data loading experience. For more detailed feature introductions, refer to the documentation. Here, we only cover how to use MindRecord in MindSpore Transformers model training tasks.
The following example uses `qwen2_5-0.5b` fine-tuning to explain the related functionalities. The provided scripts are only applicable to the specified dataset. If you need to process a custom dataset, please refer to MindRecord format conversion for data preprocessing.
Data Preprocessing
1. Download the alpaca dataset: Link

2. Execute the data processing script to convert the alpaca dataset into a dialogue format:

    python research/qwen2/alpaca_converter.py \
      --data_path /path/alpaca_data.json \
      --output_path /path/alpaca-data-messages.json

    Here, `data_path` is the path where the downloaded alpaca dataset is stored, and `output_path` is the save path for the generated dialogue-format data file.

3. Execute the script to convert the dialogue-format data file into the MindRecord format:

    python research/qwen2/qwen2_preprocess.py \
      --dataset_type 'qa' \
      --input_glob /path/alpaca-data-messages.json \
      --vocab_file /path/vocab.json \
      --merges_file /path/merges.txt \
      --seq_length 32768 \
      --output_file /path/alpaca-messages.mindrecord

    The script parameters are explained as follows:

    - `dataset_type`: Type of data preprocessing. For the alpaca dataset, set this to `qa`.
    - `input_glob`: Path to the dialogue-format data file.
    - `vocab_file`: Path to the `vocab.json` file of the qwen2 model.
    - `merges_file`: Path to the `merges.txt` file of the qwen2 model.
    - `seq_length`: Sequence length for generating the MindRecord data.
    - `output_file`: Save path for the generated MindRecord data.

    The `vocab_file` and `merges_file` can be obtained from the qwen2 model repository on the HuggingFace community.
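For custom datasets not covered by the scripts above, a minimal sketch of writing a MindRecord file with MindSpore's `FileWriter` might look like the following. The schema fields, token IDs, and paths are illustrative assumptions, not the exact format the scripts above produce:

import numpy as np
from mindspore.mindrecord import FileWriter

# Define a simple schema; the field names and shapes are assumptions for illustration.
schema = {"input_ids": {"type": "int32", "shape": [-1]},
          "labels": {"type": "int32", "shape": [-1]}}

writer = FileWriter(file_name="/path/custom_data.mindrecord", shard_num=1)
writer.add_schema(schema, "custom qa dataset")

# Each record is a dict matching the schema; the token IDs here are dummy values.
samples = [{"input_ids": np.array([1, 2, 3], dtype=np.int32),
            "labels": np.array([1, 2, 3], dtype=np.int32)}]
writer.write_raw_data(samples)
writer.commit()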
Model Fine-tuning
Following the data preprocessing steps above, you can generate a MindRecord dataset for fine-tuning the `qwen2_5-0.5b` model. Below is an introduction on how to use the generated data file to start the model fine-tuning task.
1. Modify the model configuration file.

    The `qwen2_5-0.5b` model fine-tuning uses the finetune_qwen2_5_0.5b_8k.yaml configuration file. Modify the dataset section as follows:

    train_dataset: &train_dataset
      data_loader:
        type: MindDataset
        dataset_dir: "/path/alpaca-messages.mindrecord"
        shuffle: True

    When using the MindRecord dataset in a model training task, the following configurations in `data_loader` need to be modified:

    - `type`: Type of the data_loader. Set to `MindDataset` when using MindRecord datasets.
    - `dataset_dir`: Path to the MindRecord data files.
    - `shuffle`: Whether to randomly sample data samples during training.
2. Start model fine-tuning.

    After modifying the dataset and parallel-related configurations in the model configuration file, you can refer to the model documentation to launch the fine-tuning task. Here, we take the Qwen2_5 model documentation as an example.
Multi-source Datasets
The native MindSpore dataset loading module MindDataset has performance bottlenecks when loading and sampling multiple MindRecord datasets. Therefore, MindSpore Transformers implements the `MultiSourceDataLoader` to achieve efficient loading and sampling across multiple datasets.
The multi-source dataset functionality is enabled mainly by modifying the `data_loader` configuration in the config file. Below is an example:
train_dataset: &train_dataset
  data_loader:
    type: MultiSourceDataLoader
    data_source_type: random_access
    shuffle: True
    dataset_ratios: [0.2, 0.8]
    samples_count: 1000
    nums_per_dataset: [2000]
    sub_data_loader_args:
      stage: 'train'
      column_names: ["input_ids", "target_ids", "attention_mask"]
    sub_data_loader:
      - type: MindDataset
        dataset_files: "/path/alpaca-messages.mindrecord"
      - type: MindDataset
        dataset_files: "/path/alpaca-messages.mindrecord"
    load_indices_npz_path: '/path/index.npz'
    save_indices_npz_path: '/path/index.npz'
The `shuffle` setting affects two parameters, `shuffle_dataset` and `shuffle_file`:

- `shuffle_dataset` indicates random sampling at the sub-dataset level.
- `shuffle_file` indicates random sampling at the sample level.

The effects of the different `shuffle` values are as follows:
| shuffle | shuffle_dataset | shuffle_file |
|---|---|---|
| True | True | True |
| False | False | False |
| infile | False | True |
| files | True | False |
| global | True | True |
Other configuration parameters are explained below:

| Parameter | Description | Type |
|---|---|---|
| dataset_ratios | Sampling ratios for each sub-dataset; the sum of all ratios equals 1 | list |
| samples_count | Number of samples drawn from the sub-datasets; effective only when `dataset_ratios` is configured | int |
| nums_per_dataset | Number of samples per sub-dataset; effective when `dataset_ratios` is not configured | list |
| sub_data_loader_args | Common configurations applied to every sub-dataset during sub-dataset construction | dict |
| sub_data_loader | Configuration of each sub-dataset, the same as the `data_loader` configuration of a single dataset | list |
| load_indices_npz_path | Path from which to load the data index file | str |
| save_indices_npz_path | Path to which the data index file is saved | str |
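As a loose illustration of what such an index file can contain (not the exact format `MultiSourceDataLoader` writes), the sketch below builds a global sampling index over two sub-datasets according to `dataset_ratios` and saves it with numpy; the sizes and paths are made up:

import numpy as np

# Assumed sub-dataset sizes and sampling setup, purely for illustration.
sub_dataset_sizes = [2000, 2000]
dataset_ratios = [0.2, 0.8]
samples_count = 1000
rng = np.random.default_rng(1234)

dataset_ids, sample_ids = [], []
for ds_id, (size, ratio) in enumerate(zip(sub_dataset_sizes, dataset_ratios)):
    n = int(samples_count * ratio)
    dataset_ids.append(np.full(n, ds_id, dtype=np.int64))
    sample_ids.append(rng.integers(0, size, n))  # sample-level indices within the sub-dataset

order = rng.permutation(samples_count)  # shuffle across sub-datasets
np.savez("/path/index.npz",
         dataset_ids=np.concatenate(dataset_ids)[order],
         sample_ids=np.concatenate(sample_ids)[order])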