# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Pretrain GPT."""

import os
from functools import partial
from typing import Union

import torch
import torch_npu
from torch import Tensor
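# The next three imports are used for their side effects: megatron_adaptor applies
# AscendSpeed's NPU patches to Megatron-LM, transfer_to_npu redirects CUDA device
# calls to the Ascend NPU, and importing modellink registers ModelLink's extensions
# (including the instruction-dataset support used below).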
from ascendspeed import megatron_adaptor
from torch_npu.contrib import transfer_to_npu
import modellink
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
from megatron.core.datasets.gpt_dataset import GPTDataset
import megatron.model
from megatron.core.models.gpt import GPTModel
from megatron.training import pretrain
from megatron.core.transformer.spec_utils import import_module
from megatron.utils import (
    get_ltor_masks_and_position_ids,
    get_batch_on_this_cp_rank,
    average_losses_across_data_parallel_group
)
from megatron.arguments import core_transformer_config_from_args
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
    gpt_layer_with_transformer_engine_spec_moe
)
from modellink.data.decoder_packed_mtf_dataset import build_train_valid_test_datasets as build_instruction_dataset
from modellink.utils import get_tune_attention_mask


def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.model.GPTModel]:
    """Builds the model.

    If args.use_mcore_models is True, this returns a Megatron Core (mcore) GPT model;
    otherwise it returns the legacy megatron.model.GPTModel.

    Args:
        pre_process (bool, optional): Set to true if you need to compute embeddings. Defaults to True.
        post_process (bool, optional): Set to true if you want to compute output logits/loss. Defaults to True.

    Returns:
        Union[GPTModel, megatron.model.GPTModel]: The returned model
    """
    args = get_args()

    print_rank_0('building GPT model ...')
    config = core_transformer_config_from_args(get_args())
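    # Select the transformer layer spec: an explicitly provided spec module takes
    # priority; otherwise use the dense Transformer Engine spec, or the MoE spec
    # when --num-experts is set.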
    if args.use_mcore_models:
        if args.spec is not None:
            transformer_layer_spec = import_module(args.spec)
        else:
            if args.num_experts is None:
                transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec()
            else:
                transformer_layer_spec = gpt_layer_with_transformer_engine_spec_moe

        model = GPTModel(
            config=config,
            transformer_layer_spec=transformer_layer_spec,
            vocab_size=args.padded_vocab_size,
            max_sequence_length=args.max_position_embeddings,
            pre_process=pre_process,
            post_process=post_process,
            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
            parallel_output=True,
            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
            position_embedding_type=args.position_embedding_type,
            rotary_percent=args.rotary_percent
        )
    else:
        if not args.context_parallel_size == 1:
            raise ValueError("Context parallelism is only supported with Megatron Core!")

        model = megatron.model.GPTModel(
            config,
            num_tokentypes=0,
            parallel_output=True,
            pre_process=pre_process,
            post_process=post_process
        )

    return model


def get_batch(data_iterator):
    """Generate a batch."""
    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
        return None, None, None, None, None

    args = get_args()
    tokenizer = get_tokenizer()

    if args.is_instruction_dataset:
        # Items and their type.
        keys = ['input_ids', 'attention_mask', 'labels']
        data_type = torch.int64

        # Broadcast data.
        data_b = tensor_parallel.broadcast_data(keys, next(data_iterator), data_type)

        # Unpack.
        labels = data_b.get('labels').long()
        tokens = data_b.get('input_ids').long()
        attention_mask_1d = data_b.get('attention_mask').long()
        # Positions labeled -100 are ignored in the loss.
        loss_mask = torch.where(labels == -100, 0, 1)

        attention_mask = get_tune_attention_mask(attention_mask_1d, args.reset_attention_mask)

        return tokens, labels, loss_mask, attention_mask, None
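    # Pretraining path: each sample is a single 'text' sequence of token ids; inputs
    # and labels are produced by shifting the sequence by one position.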
    # Items and their type.
    keys = ['text']
    datatype = torch.int64

    # Broadcast data.
    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    data_b = tensor_parallel.broadcast_data(keys, data, datatype)

    # Unpack.
    tokens_ = data_b['text'].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        tokenizer.eod,
        args.reset_position_ids,
        args.reset_attention_mask,
        args.eod_mask_loss)

    batch = {
        'tokens': tokens,
        'labels': labels,
        'loss_mask': loss_mask,
        'attention_mask': attention_mask,
        'position_ids': position_ids
    }
    # Slice the batch along the sequence dimension for context parallelism.
    batch = get_batch_on_this_cp_rank(batch)

    return batch.values()


def loss_func(loss_mask: Tensor, output_tensor: Tensor):
    """Loss function.

    Args:
        loss_mask (Tensor): Used to mask out some portions of the loss
        output_tensor (Tensor): The tensor with the losses
    """
    args = get_args()

    losses = output_tensor.float()
    loss_mask = loss_mask.view(-1).float()
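    # With context parallelism each rank holds only a slice of the sequence, so the
    # masked loss sum and the valid-token count are all-reduced over the
    # context-parallel group before normalizing; the scaling by context_parallel_size
    # in the return below compensates for the loss/gradient averaging performed
    # across context-parallel ranks later in training.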
    if args.context_parallel_size > 1:
        loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
        loss = loss[0] / loss[1]
    else:
        loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()

    # Check individual rank losses are not NaN prior to DP all-reduce.
    if args.check_for_nan_in_loss_and_grad:
        global_rank = torch.distributed.get_rank()
        if loss.isnan():
            raise ValueError(f'Rank {global_rank}: found NaN in local forward loss calculation. '
                             f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}')

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss * args.context_parallel_size, {'lm loss': averaged_loss[0]}


def forward_step(data_iterator, model: GPTModel):
    """Forward training step.

    Args:
        data_iterator : Input data iterator
        model (GPTModel): The GPT Model
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        data_iterator)
    timers('batch-generator').stop()

    output_tensor = model(tokens, position_ids, attention_mask,
                          labels=labels)

    return output_tensor, partial(loss_func, loss_mask)
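

# Datasets only need to exist on ranks that actually read data: the first and last
# pipeline stages, and only on tensor-parallel rank 0 (get_batch broadcasts the data
# to the rest of the tensor-parallel group).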
def is_dataset_built_on_rank():
    return (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and mpu.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args):
    return GPTDatasetConfig(
        is_built_on_rank=is_dataset_built_on_rank,
        random_seed=args.seed,
        sequence_length=args.seq_length,
        blend=args.data_path,
        blend_per_split=[args.train_data_path, args.valid_data_path, args.test_data_path],
        split=args.split,
        path_to_cache=args.data_cache_path,
    )


def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build the train, validation, and test datasets.

    Args:
        train_val_test_num_samples : A list containing the number of samples for the train, validation, and test sets.
    """
    args = get_args()

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    if args.is_instruction_dataset:
        train_ds, valid_ds, test_ds = build_instruction_dataset(
            data_prefix=args.data_path,
            splits_string=args.split,
            train_valid_test_num_samples=train_val_test_num_samples,
            seq_length=args.seq_length,
            seed=args.seed)
    else:
        train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
            GPTDataset,
            train_val_test_num_samples,
            core_gpt_dataset_config_from_args(args)
        ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


if __name__ == "__main__":
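    # On Ascend NPUs, jit_compile=True selects online (just-in-time) operator
    # compilation via torch_npu rather than relying only on pre-built binary kernels.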
    torch.npu.set_compile_mode(jit_compile=True)

    # Temporary for transition to core datasets
    train_valid_test_datasets_provider.is_distributed = True

    pretrain(train_valid_test_datasets_provider,
             model_provider,
             ModelType.encoder_or_decoder,
             forward_step,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
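
# Illustrative launch sketch only (not taken from this repository's scripts): the
# entry point is normally started with torchrun and standard Megatron-style
# arguments, roughly along these lines. The flags and sizes below are placeholders,
# so consult the repository's example scripts for the options it actually uses.
#
#   torchrun --nproc_per_node=8 pretrain_gpt.py \
#       --tensor-model-parallel-size 2 \
#       --pipeline-model-parallel-size 2 \
#       --num-layers 24 --hidden-size 1024 --num-attention-heads 16 \
#       --seq-length 2048 --max-position-embeddings 2048 \
#       --micro-batch-size 4 --global-batch-size 64 \
#       --train-iters 5000 --lr 1.0e-4 \
#       --data-path <path/to/dataset_prefix> \
#       --tokenizer-type GPT2BPETokenizer \
#       --vocab-file <vocab.json> --merge-file <merges.txt>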