当前位置: 首页 > news >正文

网站建设.c重庆搜索引擎推广公司

网站建设.c,重庆搜索引擎推广公司,投资理财网站开发,那个网站可以做家具效果图序言 本文总结一下目前TRL典型的训练器的实现细节#xff08;SFT#xff0c;PPO#xff0c;DPO#xff0c;GRPO#xff09;#xff0c;也是对上一文【速写】PPOTrainer样例与错误思考#xff08;少量DAPO#xff09;的补充 目前DeepSeek关于各个训练器细节的掌握SFTPPODPOGRPO也是对上一文【速写】PPOTrainer样例与错误思考少量DAPO的补充 目前DeepSeek关于各个训练器细节的掌握尤其是PPOTrainer的问题依然回答得很差这个在上文中已经详细指出它写的代码大多数都是跑不通的而官方给出的PPO示例ppo.py似乎也有一些瑕疵本文将会一一指出与各位探讨。 目前笔者写了一个对于4个典型训练器SFTPPODPOGRPO都适用的一个base_pipeline以及对应的单元测试模块其中PPOTrainer是最为繁琐的一个训练器它需要更多的模型奖励模型价值模型参考模型并且对应的训练数据集train_dataset的处理方式与其他训练器存在显著区别。根据这份base_pipeline的代码我将逐一探讨TRL训练器中的细节问题谨以抛砖引玉欢迎探讨。 文章目录 序言1 TRL通用的训练器Pipeline测试及其细节说明1.1 训练配置参数的设置Config1.2 加载分词器与模型tokenizer model1.3 数据集分割与字段问题dataset1.4 训练与保存checkpoint1.5 单元测试 2 一些其他的问题2.1 DataProcessor与DataCollator2.1.1 dataset.map(prompt_formatter):2.1.2 data_collator:2.1.3 何时选择哪种方式 2.2 关于PartialState2.2.1 核心功能2.2.2 典型使用场景2.2.3 参数与底层机制2.2.4 与类似方法的区别2.2.5 完整示例分布式训练中的数据加载2.2.6 注意事项2.2.7 总结 2.3 DPO和PPO的reference_model的区别2.3.1 DPO的核心思想与Reference Model的作用2.3.2 **为什么DPO需要Reference Model尽管它不是RL**2.3.3 TRL的DPOTrainer中的ref_model2.3.4 **DPO vs PPO的Reference Model**2.3.5 **如果没有Reference Model会怎样** 1 TRL通用的训练器Pipeline测试及其细节说明 完整的项目在GitHubcaoyang-sufe/easyllm以下我们先来看下面代码中的base_pipeline其他4个pipelinesft_pipelineppo_pipelinedpo_pipelinegrpo_pipeline都是直接调用它使用。 对应的脚本在trainer.py 这个pipeline很重要接下来所有的讨论都是围绕它逐行展开的。 # -*- coding: utf8 -*- # author: caoyang # email: caoyangstu.sufe.edu.cnimport wandb import logging from copy import deepcopy from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, HfArgumentParser from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model from trl import (ScriptArguments, ModelConfig, SFTConfig, SFTTrainer,PPOConfig, PPOTrainer,DPOConfig, DPOTrainer,GRPOConfig, GRPOTrainer,get_peft_config, get_quantization_config, ) from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE from src.tools.trl import update_trl_config, generate_simple_data_processor# Trainer Pipeline # param name: [Str] e.g. SFT, PPO, DPO, GRPO # param data_processor: Function object prepared for dataset.map(data_processor) # param trainer_config: [Dict, peft.XXXConfig] including keyword arguments, e.g. # param model_config: [Dict, peft.ModelConfig] including keyword arguments, e.g. # param script_arguments: [Dict, peft.ScriptArguments] including keyword arguments, e.g. dataset_name, dataset_train_split, dataset_test_split # param config_kwargs: [Dict] keyword arguments for updating TRL-Config, ScriptArguments, ModelConfig # - keyword arguments for TRLConfig: e.g. output_dir, adam_xxx, learning_rate, kl_coef, push_to_hub # - keyword arguments for ScriptArguments: e.g. output_dir, adam_xxx, learning_rate, kl_coef, push_to_hub # - keyword arguments for ModelConfig: e.g. model_name_or_path, torch_dtype, trust_remote_code, use_peft, lora_xxx, load_in_4bit, bnb_4bit_compute_dtype, bnb_4bit_quant_type # param trainer_kwargs: [Dict] keyword arguments for updating TRL-Trainer # - keyword arguments for all Trainers: e.g. data_collator, callbacks # - keyword arguments for SFTTrainer: e.g. compute_loss_func, compute_metrics # - keyword arguments for PPOTrainer: e.g. ref_model[required], reward_model[required], value_model[required] # - keyword arguments for DPOTrainer: e.g. ref_model # - keyword arguments for GRPOTrainer: e.g. reward_funcs[required] def base_pipeline(name, data_processor, config_kwargs, trainer_kwargs):# 1 ConfigurationTRLConfig, TRLTrainer eval(f{name}Config), eval(f{name}Trainer)parser HfArgumentParser((ScriptArguments, TRLConfig, ModelConfig))script_arguments, trainer_config, model_config parser.parse_args_into_dataclasses()script_arguments update_trl_config(script_arguments, **config_kwargs)trainer_config update_trl_config(trainer_config, **config_kwargs)model_config update_trl_config(model_config, **config_kwargs)peft_config get_peft_config(model_config)quantization_config get_quantization_config(model_config)# 2 Load models and tokenizerlogging.info(Load models and tokenizer ...)logging.info(f - Model: {model_config.model_name_or_path})tokenizer AutoTokenizer.from_pretrained(model_config.model_name_or_path)if not pad_token in tokenizer.special_tokens_map:tokenizer.add_special_tokens({pad_token: [PAD]})if tokenizer.chat_template is None:tokenizer.chat_template SIMPLE_CHAT_TEMPLATEmodel AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path,device_map auto,trust_remote_code model_config.trust_remote_code,quantization_config quantization_config,)if peft_config is not None:logging.info(Prepare model for PEFT ...)model.config.pretraining_tp 1model.config.use_cache Falsemodel.gradient_checkpointing_enable()# If prepare_model_for_kbit_training is ignored, and gradient_checkpointing True (for GPU memory saving)# Then you need set model.enable_input_require_grads() yourself# model prepare_model_for_kbit_training(model)model.enable_input_require_grads()model get_peft_model(model, peft_config)if name PPO:logging.info(PPO load reward value and reference models ...)# PPO is special! It needs more components!logging.info(f - Reward model: {trainer_config.reward_model_path})reward_model AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code model_config.trust_remote_code,num_labels 1,)value_model AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code model_config.trust_remote_code,num_labels 1,)logging.info( - Copy reference model ...)ref_model deepcopy(model)# ref_model model.__class__(model.config)# ref_model.load_state_dict(model.state_dict())trainer_kwargs[reward_model] reward_modeltrainer_kwargs[value_model] value_modeltrainer_kwargs[ref_model] ref_modellogging.info( - Done!)if data_processor is None:# The data processor of PPO is also different to othersdef data_processor(_data):outputs tokenizer(_data[prompt] _data[completion], padding False)return {input_ids: outputs[input_ids]}# 2 Load datasetlogging.info(Load dataset ...)logging.info(f - Dataset: {script_arguments.dataset_name})if data_processor is None:data_processor generate_simple_data_processor(name)train_dataset load_dataset(script_arguments.dataset_name, splitscript_arguments.dataset_train_split)eval_dataset load_dataset(script_arguments.dataset_name, splitscript_arguments.dataset_test_split)train_dataset train_dataset.map(data_processor, remove_columnstrain_dataset.column_names)eval_dataset eval_dataset.map(data_processor, remove_columnseval_dataset.column_names)logging.info(f - Train dataset: {len(train_dataset)})logging.info(f - Eval dataset: {len(eval_dataset)})# 4 Train modellogging.info(Trainer starts ...)trainer TRLTrainer(model model,args trainer_config,train_dataset train_dataset,eval_dataset eval_dataset,processing_class tokenizer,peft_config peft_config,**trainer_kwargs)trainer.train()logging.info( - Trainer finishes!)# 5 Save modelif trainer_config.push_to_hub:logging.info(f - Push checkpoints to {trainer_config.organization}/{trainer_config.push_to_hub_model_id})trainer.push_to_hub()logging.info(fSave model to {trainer_config.output_dir})trainer.save_model(trainer_config.output_dir) # SFT Pipeline def sft_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name SFT,data_processor data_processor,config_kwargs config_kwargs,trainer_kwargs trainer_kwargs,) # PPO Pipeline def ppo_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name PPO,data_processor data_processor,config_kwargs config_kwargs,trainer_kwargs trainer_kwargs,) # DPO Pipeline def dpo_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name DPO,data_processor data_processor,config_kwargs config_kwargs,trainer_kwargs trainer_kwargs,) # GRPO Pipeline def grpo_pipeline(data_processor, config_kwargs, trainer_kwargs):base_pipeline(name GRPO,data_processor data_processor,config_kwargs config_kwargs,trainer_kwargs trainer_kwargs,)对应的单元测试脚本在trainer_pipelines.py这里面涉及3个模型和4个数据集都可以在huggingface上直接下载得到model_home和dataset_home可根据本地路径进行修改。 # -*- coding: utf8 -*- # author: caoyang # email: caoyangstu.sufe.edu.cnimport os import logging from src.pipelines.trainer import base_pipeline, sft_pipeline, ppo_pipeline, dpo_pipeline, grpo_pipelinemodel_home /nfsshare/home/caoyang/resource/model dataset_home /nfsshare/home/caoyang/resource/dataset model_names [Qwen/Qwen2.5-0.5B-Instruct,EleutherAI/pythia-1b-deduped,EleutherAI/pythia-160m, ]dataset_names [trl-lib/tldr, # train[prompt, completion] validation[prompt, completion] test[prompt, completion]trl-lib/ultrafeedback_binarized, # train[chosen, rejected, score_chosen, score_rejected] test[chosen, rejected, score_chosen, score_rejected]trl-internal-testing/descriptiveness-sentiment-trl-style, # sentiment[prompt, chosen, rejected] descriptiveness[prompt, chosen, rejected]YeungNLP/firefly-train-1.1M, # train[input, target] ]def sft_pipeline_test():logging.info(SFT unittest ...)model_name_or_path os.path.join(model_home, model_names[0])dataset_name os.path.join(dataset_home, dataset_names[0])data_processor Noneconfig_kwargs {output_dir: f./temp/sft{model_name_or_path.split(/)[-1]}{dataset_name.split(/)[-1]},model_name_or_path: model_name_or_path,dataset_name: dataset_name,trust_remote_code: True,dataset_train_split: train[:500],dataset_test_split: validation[500:600],use_peft: True,report_to: none,lora_target_modules: [q_proj, k_proj, v_proj]}trainer_kwargs {}sft_pipeline(data_processor, config_kwargs, trainer_kwargs)def ppo_pipeline_test():logging.info(PPO unittest ...)model_name_or_path os.path.join(model_home, model_names[1])EleutherAI/pythia-1b-dedupedGPTNeoXForCausalLM((gpt_neox): GPTNeoXModel((embed_in): Embedding(50304, 2048)(emb_dropout): Dropout(p0.0, inplaceFalse)(layers): ModuleList((0-15): 16 x GPTNeoXLayer((input_layernorm): LayerNorm((2048,), eps1e-05, elementwise_affineTrue)(post_attention_layernorm): LayerNorm((2048,), eps1e-05, elementwise_affineTrue)(post_attention_dropout): Dropout(p0.0, inplaceFalse)(post_mlp_dropout): Dropout(p0.0, inplaceFalse)(attention): GPTNeoXAttention((query_key_value): Linear(in_features2048, out_features6144, biasTrue)(dense): Linear(in_features2048, out_features2048, biasTrue))(mlp): GPTNeoXMLP((dense_h_to_4h): Linear(in_features2048, out_features8192, biasTrue)(dense_4h_to_h): Linear(in_features8192, out_features2048, biasTrue)(act): GELUActivation())))(final_layer_norm): LayerNorm((2048,), eps1e-05, elementwise_affineTrue)(rotary_emb): GPTNeoXRotaryEmbedding())(embed_out): Linear(in_features2048, out_features50304, biasFalse))dataset_name os.path.join(dataset_home, dataset_names[0])reward_model_path os.path.join(model_home, model_names[2])data_processor Noneconfig_kwargs {output_dir: f./temp/ppo{model_name_or_path.split(/)[-1]}{dataset_name.split(/)[-1]},model_name_or_path: model_name_or_path,dataset_name: dataset_name,reward_model_path: reward_model_path,trust_remote_code: True,dataset_train_split: train[:500],dataset_test_split: validation[:100],use_peft: True,report_to: none,lora_target_modules: [query_key_value],}trainer_kwargs {}ppo_pipeline(data_processor, config_kwargs, trainer_kwargs)def dpo_pipeline_test():logging.info(DPO unittest ...)model_name_or_path os.path.join(model_home, model_names[0])dataset_name os.path.join(dataset_home, dataset_names[2])data_processor Noneconfig_kwargs {output_dir: f./temp/dpo{model_name_or_path.split(/)[-1]}{dataset_name.split(/)[-1]},model_name_or_path: model_name_or_path,dataset_name: dataset_name,trust_remote_code: True,dataset_train_split: descriptiveness[:500],dataset_test_split: descriptiveness[500:600],use_peft: True,report_to: none,lora_target_modules: [q_proj, k_proj, v_proj]}trainer_kwargs {}dpo_pipeline(data_processor, config_kwargs, trainer_kwargs)def grpo_pipeline_test():logging.info(GRPO unittest ...)model_name_or_path os.path.join(model_home, model_names[0])dataset_name os.path.join(dataset_home, dataset_names[0])data_processor Nonedef reward_funcs(completions, **kwargs):return [float(len(set(completion))) for completion in completions]config_kwargs {output_dir: f./temp/grpo{model_name_or_path.split(/)[-1]}{dataset_name.split(/)[-1]},model_name_or_path: model_name_or_path,dataset_name: dataset_name,trust_remote_code: True,dataset_train_split: train[:500],dataset_test_split: validation[:100],use_peft: True,report_to: none,lora_target_modules: [q_proj, k_proj, v_proj]}trainer_kwargs {reward_funcs: reward_funcs,}grpo_pipeline(data_processor, config_kwargs, trainer_kwargs)目前单元测试都能通过关键库版本如下 accelerate1.6.0 datasets3.5.0 peft0.15.2 torch2.5.1 transformers4.51.3 trl0.17.01.1 训练配置参数的设置Config 目前模型训练涉及的参数非常的广泛以前搭积木时代训练模型时我们一般会自定义一个Config类用于管理与模型、训练、数据集等相关的参数但目前Transformers提供了非常好的工具HfArgumentParser用于管理这些繁杂的参数根据官方给出的PPO示例ppo.py中的写法 parser HfArgumentParser((ScriptArguments, PPOConfig, ModelConfig)) script_args, training_args, model_args parser.parse_args_into_dataclasses()将所有参数分为三类ScriptArguments, PPOConfig, ModelConfig 使用HfArgumentParser解析参数的逻辑是这个过程没有读源码可能实际上略有出入但是结果应该是一致的首先使用默认值初始化ScriptArguments, PPOConfig, ModelConfig对应的三个对象然后将执行脚本中传入的参数按名称对应分配给每个对象。 例如在PPO的例子的执行脚本中 python -i examples/scripts/ppo/ppo.py \--dataset_name trl-internal-testing/descriptiveness-sentiment-trl-style \--dataset_train_split descriptiveness \--learning_rate 3e-6 \--output_dir models/minimal/ppo \--per_device_train_batch_size 64 \--gradient_accumulation_steps 1 \--total_episodes 10000 \--model_name_or_path EleutherAI/pythia-1b-deduped \--missing_eos_penalty 1.0会将dataset_name和dataset_train_split的值更新给script_argslearning_rate 和output_dir的值更新给model_args 不妨可以看看script_args, training_args, model_args分别包含了哪些参数 script_args ScriptArguments(dataset_nameNone, dataset_configNone, dataset_train_splittrain, dataset_test_splittest, gradient_checkpointing_use_reentrantFalse, ignore_bias_buffersFalse )training_args即ppo_config PPOConfig(output_dirtrainer_output,overwrite_output_dirFalse,do_trainFalse,do_evalFalse,do_predictFalse,eval_strategyIntervalStrategy.NO: no,prediction_loss_onlyFalse,per_device_train_batch_size8,per_device_eval_batch_size8,per_gpu_train_batch_sizeNone,per_gpu_eval_batch_sizeNone,gradient_accumulation_steps1,eval_accumulation_stepsNone,eval_delay0,torch_empty_cache_stepsNone,learning_rate5e-05,weight_decay0.0,adam_beta10.9,adam_beta20.999,adam_epsilon1e-08,max_grad_norm1.0,num_train_epochs3.0,max_steps-1,lr_scheduler_typeSchedulerType.LINEAR: linear,lr_scheduler_kwargs{},warmup_ratio0.0,warmup_steps0,log_levelpassive,log_level_replicawarning,log_on_each_nodeTrue,logging_dirtrainer_output\\runs\\Jun14_11-14-00_LAPTOP-PJP6MGE1,logging_strategyIntervalStrategy.STEPS: steps,logging_first_stepFalse,logging_steps500,logging_nan_inf_filterTrue,save_strategySaveStrategy.STEPS: steps,save_steps500,save_total_limitNone,save_safetensorsTrue,save_on_each_nodeFalse,save_only_modelFalse,restore_callback_states_from_checkpointFalse,no_cudaFalse,use_cpuFalse,use_mps_deviceFalse,seed42,data_seedNone,jit_mode_evalFalse,use_ipexFalse,bf16False,fp16False,fp16_opt_levelO1,half_precision_backendauto,bf16_full_evalFalse,fp16_full_evalFalse,tf32None,local_rank0,ddp_backendNone,tpu_num_coresNone,tpu_metrics_debugFalse,debug[],dataloader_drop_lastFalse,eval_stepsNone,dataloader_num_workers0,dataloader_prefetch_factorNone,past_index-1,run_nametrainer_output,disable_tqdmFalse,remove_unused_columnsTrue,label_namesNone,load_best_model_at_endFalse,metric_for_best_modelNone,greater_is_betterNone,ignore_data_skipFalse,fsdp[],fsdp_min_num_params0,fsdp_config{min_num_params: 0,xla: False,xla_fsdp_v2: False,xla_fsdp_grad_ckpt: False},fsdp_transformer_layer_cls_to_wrapNone,accelerator_configAcceleratorConfig(split_batchesFalse,dispatch_batchesNone,even_batchesTrue,use_seedable_samplerTrue,non_blockingFalse,gradient_accumulation_kwargsNone,use_configured_stateFalse),deepspeedNone,label_smoothing_factor0.0,optimOptimizerNames.ADAMW_TORCH: adamw_torch,optim_argsNone,adafactorFalse,group_by_lengthFalse,length_column_namelength,report_to[wandb],ddp_find_unused_parametersNone,ddp_bucket_cap_mbNone,ddp_broadcast_buffersNone,dataloader_pin_memoryTrue,dataloader_persistent_workersFalse,skip_memory_metricsTrue,use_legacy_prediction_loopFalse,push_to_hubFalse,resume_from_checkpointNone,hub_model_idNone,hub_strategyHubStrategy.EVERY_SAVE: every_save,hub_tokenNone,hub_private_repoNone,hub_always_pushFalse,gradient_checkpointingFalse,gradient_checkpointing_kwargsNone,include_inputs_for_metricsFalse,include_for_metrics[],eval_do_concat_batchesTrue,fp16_backendauto,push_to_hub_model_idNone,push_to_hub_organizationNone,push_to_hub_tokenNone,mp_parameters,auto_find_batch_sizeFalse,full_determinismFalse,torchdynamoNone,ray_scopelast,ddp_timeout1800,torch_compileFalse,torch_compile_backendNone,torch_compile_modeNone,include_tokens_per_secondFalse,include_num_input_tokens_seenFalse,neftune_noise_alphaNone,optim_target_modulesNone,batch_eval_metricsFalse,eval_on_startFalse,use_liger_kernelFalse,eval_use_gather_objectFalse,average_tokens_across_devicesFalse,dataset_num_procNone,num_mini_batches1,total_episodesNone,local_rollout_forward_batch_size64,num_sample_generations10,response_length53,stop_tokenNone,stop_token_idNone,temperature0.7,missing_eos_penaltyNone,sft_model_pathEleutherAI/pythia-160m,world_sizeNone,num_total_batchesNone,micro_batch_sizeNone,local_batch_sizeNone,batch_sizeNone,local_mini_batch_sizeNone,mini_batch_sizeNone,exp_nameppo_config,reward_model_pathEleutherAI/pythia-160m,model_adapter_nameNone,ref_adapter_nameNone,num_ppo_epochs4,whiten_rewardsFalse,kl_coef0.05,kl_estimatork1,cliprange0.2,vf_coef0.1,cliprange_value0.2,gamma1.0,lam0.95,ds3_gather_for_generationTrue )model_args ModelConfig(model_name_or_pathNone, model_revisionmain, torch_dtypeNone, trust_remote_codeFalse, attn_implementationNone, use_peftFalse, lora_r16, lora_alpha32, lora_dropout0.05, lora_target_modulesNone, lora_modules_to_saveNone, lora_task_typeCAUSAL_LM, use_rsloraFalse, use_doraFalse, load_in_8bitFalse, load_in_4bitFalse, bnb_4bit_quant_typenf4, use_bnb_nested_quantFalse, )其中 script_args主要控制数据集的参数。特别地dataset_name是required但这个事情比较奇怪 如果是在Linux上使用脚本启动的话是必须传入--dataset_name或者--dataset-name参数的否则会显示 error: the following arguments are required: --dataset_name/--dataset-nametransformers4.51.3, trl0.17.0在Windows系统上即使不传入任何参数也是不会报错的transformers4.52.4, trl0.18.1比Linux上的版本要高一点但应该不是版本问题。 training_args即PPOConfig的对象对应PPOTrainer中的args参数其中包含了非常多与训练参数具体每个训练器及其配置的参数列表可以直接到HuggingFace上的TRL文档查看https://huggingface.co/docs/trl这里例举常用的参数 output_dir模型checkpoint导出的路径,adam_xxxAdam优化器相关的参数似乎TRL里所有Trainer的默认优化器都是Adam不过是可以在PPOTrainer的参数中定义优化器的PPOTrainer的优化器参数optimizers是需要传入两个变量optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]一个是优化器另一个是学习率的规划器。learning_rate学习率,kl_coefKL散度的惩罚系数越大相当于策略更新幅度越小越小策略就可能更新幅度越快, model_args对应的是PPOTrainer的peft_config参数里面也有一些和模型初始化相关的参数包括量化与PEFT相关的参数也是从中提取得到的比如在base_pipeline脚本中 ... # Trainer Pipeline # param name: [Str] e.g. SFT, PPO, DPO, GRPO # param data_processor: Function object prepared for dataset.map(data_processor) # param trainer_config: [Dict, peft.XXXConfig] including keyword arguments, e.g. # param model_config: [Dict, peft.ModelConfig] including keyword arguments, e.g. # param script_arguments: [Dict, peft.ScriptArguments] including keyword arguments, e.g. dataset_name, dataset_train_split, dataset_test_split # param config_kwargs: [Dict] keyword arguments for updating TRL-Config, ScriptArguments, ModelConfig # - keyword arguments for TRLConfig: e.g. output_dir, adam_xxx, learning_rate, kl_coef, push_to_hub # - keyword arguments for ScriptArguments: e.g. output_dir, adam_xxx, learning_rate, kl_coef, push_to_hub # - keyword arguments for ModelConfig: e.g. model_name_or_path, torch_dtype, trust_remote_code, use_peft, lora_xxx, load_in_4bit, bnb_4bit_compute_dtype, bnb_4bit_quant_type # param trainer_kwargs: [Dict] keyword arguments for updating TRL-Trainer # - keyword arguments for all Trainers: e.g. data_collator, callbacks # - keyword arguments for SFTTrainer: e.g. compute_loss_func, compute_metrics # - keyword arguments for PPOTrainer: e.g. ref_model[required], reward_model[required], value_model[required] # - keyword arguments for DPOTrainer: e.g. ref_model # - keyword arguments for GRPOTrainer: e.g. reward_funcs[required] def base_pipeline(name, data_processor, config_kwargs, trainer_kwargs):# 1 ConfigurationTRLConfig, TRLTrainer eval(f{name}Config), eval(f{name}Trainer)parser HfArgumentParser((ScriptArguments, TRLConfig, ModelConfig))script_arguments, trainer_config, model_config parser.parse_args_into_dataclasses()script_arguments update_trl_config(script_arguments, **config_kwargs)trainer_config update_trl_config(trainer_config, **config_kwargs)model_config update_trl_config(model_config, **config_kwargs)peft_config get_peft_config(model_config)quantization_config get_quantization_config(model_config)...peft_config和quantization_config都是可以现成的函数生成的它们被定义在trl项目根目录下的trainer/utils.py中 def get_quantization_config(model_args: ModelConfig) - Optional[BitsAndBytesConfig]:if model_args.load_in_4bit:quantization_config BitsAndBytesConfig(load_in_4bitTrue,bnb_4bit_compute_dtypemodel_args.torch_dtype, # For consistency with model weights, we use the same value as torch_dtypebnb_4bit_quant_typemodel_args.bnb_4bit_quant_type,bnb_4bit_use_double_quantmodel_args.use_bnb_nested_quant,bnb_4bit_quant_storagemodel_args.torch_dtype,)elif model_args.load_in_8bit:quantization_config BitsAndBytesConfig(load_in_8bitTrue,)else:quantization_config Nonereturn quantization_configdef get_peft_config(model_args: ModelConfig) - Optional[PeftConfig]:if model_args.use_peft is False:return Noneif not is_peft_available():raise ValueError(You need to have PEFT library installed in your environment, make sure to install peft. Make sure to run pip install -U peft.)peft_config LoraConfig(task_typemodel_args.lora_task_type,rmodel_args.lora_r,target_modulesmodel_args.lora_target_modules,lora_alphamodel_args.lora_alpha,lora_dropoutmodel_args.lora_dropout,biasnone,use_rsloramodel_args.use_rslora,use_doramodel_args.use_dora,modules_to_savemodel_args.lora_modules_to_save,)return peft_config从上面的源码可以看出如果model_config.use_peftFalse则默认不启用PEFT得到的peft_config也就是None同理是否采用量化取决于load_in_4bit和load_in_8bit是否至少有一个是True 1.2 加载分词器与模型tokenizer model 回到base_pipeline中对应的部分 ...# 2 Load models and tokenizerlogging.info(Load models and tokenizer ...)logging.info(f - Model: {model_config.model_name_or_path})tokenizer AutoTokenizer.from_pretrained(model_config.model_name_or_path)if not pad_token in tokenizer.special_tokens_map:tokenizer.add_special_tokens({pad_token: [PAD]})if tokenizer.chat_template is None:tokenizer.chat_template SIMPLE_CHAT_TEMPLATEmodel AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path,device_map auto,trust_remote_code model_config.trust_remote_code,quantization_config quantization_config,)if peft_config is not None:logging.info(Prepare model for PEFT ...)model.config.pretraining_tp 1model.config.use_cache Falsemodel.gradient_checkpointing_enable()# If prepare_model_for_kbit_training is ignored, and gradient_checkpointing True (for GPU memory saving)# Then you need set model.enable_input_require_grads() yourself# model prepare_model_for_kbit_training(model)model.enable_input_require_grads()model get_peft_model(model, peft_config)...这里主要是几个零碎的注意点 分词器的词汇表中一定要包含pad_token具体可以通过tokenizer.special_tokens_map否则在后面trainer.train()时会报错提示你设置pad_token关于tokenizer.chat_template的问题这个到接下来的数据处理部分会详细说明针对对话类的任务一般是会采用chat_template来处理数据集的因此可以用trl.trainer.utils中提供的SIMPLE_CHAT_TEMPLATE作为缺省模型可以通过量化加载即quantization_config不为None时这样会节约内存此时一般都会做如下的设置如果不是量化加载的话可能加上这些配置也没什么关系所以我就这么写了 这里的pretraining_tp是张量并行的意思数值越大就会并行的数量越多当然只有多卡才能并行单卡只能设为1use_cache应该是指大概做model.generate()生成时是否进行kv_cache缓存model.enable_input_require_grads()也是节约内存的方法我记得是缓存梯度以达到更快的反向传播运算的作用。 ...model.config.pretraining_tp 1model.config.use_cache Falsemodel.gradient_checkpointing_enable()# If prepare_model_for_kbit_training is ignored, and gradient_checkpointing True (for GPU memory saving)# Then you need set model.enable_input_require_grads() yourself# model prepare_model_for_kbit_training(model)model.enable_input_require_grads()...然后对于PPOTrainer而言非常特殊它的构造参数还需要reward_model, value_model, ref_model三个参数因此在base_pipeline中做了额外处理 if name PPO:logging.info(PPO load reward value and reference models ...)# PPO is special! It needs more components!logging.info(f - Reward model: {trainer_config.reward_model_path})reward_model AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code model_config.trust_remote_code,num_labels 1,)value_model AutoModelForSequenceClassification.from_pretrained(trainer_config.reward_model_path,trust_remote_code model_config.trust_remote_code,num_labels 1,)logging.info( - Copy reference model ...)ref_model deepcopy(model)# ref_model model.__class__(model.config)# ref_model.load_state_dict(model.state_dict())trainer_kwargs[reward_model] reward_modeltrainer_kwargs[value_model] value_modeltrainer_kwargs[ref_model] ref_modellogging.info( - Done!)这里注意一下 一般初始化ref_model和目标模型是完全一样的这里测试了两种方法一种是直接ref_model deepcopy(model)这确实是可行的不会发生unpickled之类的问题另一种方法则是重构一个和目标模型结构相同的空模型然后加载状态字典 # ref_model model.__class__(model.config)# ref_model.load_state_dict(model.state_dict())这个方法在use_peftFalse时是可行的但是因为这里可能是一个peft_model因此model.__class__的参数还需要一个peft_config直接这样写是会报错的使用HfArgumentParser加载PPOConfig参数时ppo_config.reward_model_path是有默认值的可以看上面的training_args即EleutherAI/pythia-160m这个奖励模型必须是可以被AutoModelForSequenceClassification类型加载的value_model通常设置为和reward_model相同至少在官方的PPO示例中是这样的其实也可以解释我的理解是 A ( s , a ) Q ( s , a ) − V ( s ) A(s,a)Q(s,a)-V(s) A(s,a)Q(s,a)−V(s) 中这里 V ( s ) V(s) V(s)即value_model Q ( s , a ) r ( s , a ) γ V ( s ′ ) Q(s,a)r(s,a)\gamma V(s) Q(s,a)r(s,a)γV(s′)中的 r ( s , a ) r(s,a) r(s,a)即奖励函数在CAUSAL_LM的语境下其实就是句子 s s s加了一个单词 a a a后的奖励本质上都是对一句话进行评分。虽然说得通但是感觉也挺奇怪的。 1.3 数据集分割与字段问题dataset 目前主流的数据集加载都是直接用HuggingFace的datasets.load_dataset方法加载标准的数据格式就是每个样本的格式为{“column_1”: data_1, column_2: data_2}整个数据集类似jsonl的格式。 针对目前已知的几个训练器它们对数据集的格式要求大概是这样的 SFTTrainer字段要包含prompt和completion注意训练时只会在completion部分计算损失prompt部分是不会计算损失的 但是其实我发现如果数据集中只有text或者inputtarget这种字段时也是可行的需要仔细读源码才行。 DPOTrainer这个就很简单一般就是promptchosenrejected但是在官方示例使用的数据集trl-lib/ultrafeedback_binarized中字段却是chosen, rejected, score_chosen, score_rejected也就是说其实prompt可能并非必须本来也确实可以直接用空字符串替代然后数据集中如果没有回答进行评分的话可能还有一个默认的评分机制在训练器里面。 GRPOTrainer这个我也是根据官方示例来的我看到它使用的数据集是trl-lib/tldr这是一个典型的只有prompt和completion两个字段的数据集 PPOTrainer这个最为特殊之前DeepSeek一直写不对的原因也在于此 在官方给出的PPO运行脚本ppo.py中 def prepare_dataset(dataset, tokenizer):pre-tokenize the dataset before training; only collate during trainingdef tokenize(element):outputs tokenizer(element[dataset_text_field],paddingFalse,)return {input_ids: outputs[input_ids]}return dataset.map(tokenize,batchedTrue,remove_columnsdataset.column_names,num_proctraining_args.dataset_num_proc,)注意到这里是直接处理成分词后的input_ids格式 但是我本来以为PPO和GRPO是类似的因此理论上在数据集的格式要求上应该也差不了太多因此我觉得可能promptcompletion也是可行的但是测试下来会报错报错提示是一定要求是带input_ids字段的。 因此我单独给PPO写了一个数据处理的data_processor而其余的都是很简单的 # param name: [Str] e.g. SFT, PPO, DPO, GRPO def generate_simple_data_processor(name, **kwargs):if name in [SFT, GRPO]:def _data_processor(_data):return {prompt: _data[prompt], completion: _data[completion]}elif name PPO:tokenizer kwargs.get(tokenizer)def _data_processor(_data):outputs tokenizer(_data[prompt] _data[completion], padding False)return {input_ids: outputs[input_ids]}elif name DPO:def _data_processor(_data):return {prompt: _data[prompt], chosen: _data[chosen], rejected: _data[rejected]}else:raise NotImplementedError(name)return _data_processor当然一些写法也会使用apply_chat_template这通常是用于交互式对话任务都是可行的 def _data_processor(_data):_message [{role: system, content: You are an AI assistant developped by CY},{role: user, content: _data[dataset_input_column]},{role: assistant, content: _data[dataset_target_column]},]_prompt tokenizer.apply_chat_template(_message, tokenizeFalse)return {text: _prompt}然后一个小细节是在用dataset.map进行数据处理时一般会设置参数emove_columnsdataset.column_names以丢弃原先不必要的字段防止Trainer在运行时错用了其他的字段进行训练。 1.4 训练与保存checkpoint base_pipeline中最后一部分训练完然后保存模型顺利地话就完整地跑通 # 4 Train modellogging.info(Trainer starts ...)trainer TRLTrainer(model model,args trainer_config,train_dataset train_dataset,eval_dataset eval_dataset,processing_class tokenizer,peft_config peft_config,**trainer_kwargs)trainer.train()logging.info( - Trainer finishes!)# 5 Save modelif trainer_config.push_to_hub:logging.info(f - Push checkpoints to {trainer_config.organization}/{trainer_config.push_to_hub_model_id})trainer.push_to_hub()logging.info(fSave model to {trainer_config.output_dir})trainer.save_model(trainer_config.output_dir)这里值得注意的几个问题 在trainer_config即SFTConfig, PPOConfig, DPOConfig, GRPOConfig之类的对象中有一个参数trainer_config.report_to这个不设置的话默认是会上传到WB的网络不支持访问的话是会在训练到checkpoint的时候发生网络错误的因此一般会设置成none或者也可以用tensorboard 不过就训练绘图的话训练结束到trainer_config.output_dir中找到checkpoint-xxx文件夹里面会有trainer_state.json文件然后自己根据里面的数据绘图即可例如# Plot dynamics of TRL trainer state def plot_trl_dynamics(trainer_state_path):with open(trainer_state_path, r, encodingutf8) as f:data json.load(f)log_history data[log_history]steps [entry[step] for entry in log_history]episodes [entry[episode] for entry in log_history]epochs [entry[epoch] for entry in log_history]policy_loss [entry[loss/policy_avg] for entry in log_history]value_loss [entry[loss/value_avg] for entry in log_history]lrs [entry[lr] for entry in log_history]entropys [entry[objective/entropy] for entry in log_history]kls [entry[objective/kl] for entry in log_history]non_score_rewards [entry[objective/non_score_reward] for entry in log_history]rlhf_rewards [entry[objective/rlhf_reward] for entry in log_history]scores [entry[objective/scores] for entry in log_history]plt.figure(figsize(8, 8))ax_1 plt.subplot(2, 2, 1)ax_2 plt.subplot(4, 2, 2)ax_3 plt.subplot(4, 2, 4)ax_4 plt.subplot(2, 2, 3)ax_5 plt.subplot(2, 2, 4)ax_1.plot(steps, policy_loss, labelPolicy Loss)ax_1.plot(steps, value_loss, labelValue Loss, linestyle--)ax_1.set_xlabel(Step), ax_1.set_ylabel(Loss), ax_1.legend()ax_1.set_title(Policy and Value Loss)# ------------------------------------------------------------------ax_2.plot(steps, kls, labelobjective/kl)ax_2.set_xlabel(Step), ax_2.set_ylabel(KL), ax_2.legend()ax_2.set_title(KL Curve)# ------------------------------------------------------------------ax_3.plot(steps, entropys, labelobjective/entropy)ax_3.set_xlabel(Step), ax_3.set_ylabel(Entropy), ax_3.legend()ax_3.set_title(Entropy Curve)# ------------------------------------------------------------------ax_4.plot(steps, lrs, labelLearning Rate)ax_4.set_xlabel(Step), ax_4.set_ylabel(Learning Rate), ax_4.legend()ax_4.set_title(Learning Rate Curve)# ------------------------------------------------------------------ax_5.plot(steps, non_score_rewards, labelobjective/non_score_reward, linestyle--)ax_5.plot(steps, rlhf_rewards, labelobjective/rlhf_reward, linestyle--)ax_5.plot(steps, scores, labelobjective/scores)ax_5.set_xlabel(Step), ax_5.set_ylabel(Score/Reward), ax_5.legend()ax_5.set_title(Reward and Score)plt.show()print(steps)print(- * 64)print(episodes)print(- * 64)print(epochs)关于WB的使用如果网络支持使用wandb确实也是不错的选择 登录 wandb 方法 1交互式登录推荐 运行以下代码会弹出浏览器或要求输入 API Key import wandb wandb.login()然后按提示操作即可。 方法 2直接提供 API Key import wandb wandb.login(keyyour_api_key_here) # 替换成你的 wandb API Key⚠️ 注意不要直接在代码里硬编码 API Key特别是公开代码时建议使用环境变量或配置文件。 方法 3通过环境变量设置适合脚本/服务器 export WANDB_API_KEYyour_api_key_here然后在 Python 中直接 wandb.init() 即可。 步骤 3在 Trainer 中启用 wandb 方式 1自动记录默认 HuggingFace Trainer 默认会使用 wandb如果已安装并登录 from transformers import Trainer, TrainingArguments training_args TrainingArguments(output_dir./results,report_towandb, # 明确指定使用 wandb默认已经是这样logging_dir./logs,# 其他训练参数... ) trainer Trainer(modelmodel,argstraining_args,train_datasettrain_dataset,eval_dataseteval_dataset, ) trainer.train() # 自动记录到 wandb方式 2手动初始化 wandb更灵活控制 import wandb # 初始化 wandb 并设置项目名称 wandb.init(projectmy-ppo-project, nameexperiment-1) # 然后正常训练 trainer.train() # 训练结束后可选 wandb.finish()这样你可以自定义 wandb 的配置如 project、name、config 等。 额外配置可选 1. 记录超参数 config {learning_rate: 2e-5,batch_size: 32,epochs: 3, } wandb.init(projectmy-project, configconfig)训练时这些参数会显示在 wandb 面板中。 2. 自定义记录指标 # 在训练循环中记录自定义指标 wandb.log({reward: current_reward, loss: current_loss})3. 恢复中断的训练 wandb.init(id之前运行的ID, resumemust)训练开始后终端会显示类似Tracking run with wandb, view at https://wandb.ai/your-username/your-project/runs/xxx打开链接即可查看实时训练数据。 1.5 单元测试 单元测试脚本trainer_pipelines.py # -*- coding: utf8 -*- # author: caoyang # email: caoyangstu.sufe.edu.cnimport os import logging from src.pipelines.trainer import base_pipeline, sft_pipeline, ppo_pipeline, dpo_pipeline, grpo_pipelinemodel_home /nfsshare/home/caoyang/resource/model dataset_home /nfsshare/home/caoyang/resource/dataset model_names [Qwen/Qwen2.5-0.5B-Instruct,EleutherAI/pythia-1b-deduped,EleutherAI/pythia-160m, ]dataset_names [trl-lib/tldr, # train[prompt, completion] validation[prompt, completion] test[prompt, completion]trl-lib/ultrafeedback_binarized, # train[chosen, rejected, score_chosen, score_rejected] test[chosen, rejected, score_chosen, score_rejected]trl-internal-testing/descriptiveness-sentiment-trl-style, # sentiment[prompt, chosen, rejected] descriptiveness[prompt, chosen, rejected]YeungNLP/firefly-train-1.1M, # train[input, target] ]def sft_pipeline_test():logging.info(SFT unittest ...)model_name_or_path os.path.join(model_home, model_names[0])dataset_name os.path.join(dataset_home, dataset_names[0])data_processor Noneconfig_kwargs {output_dir: f./temp/sft{model_name_or_path.split(/)[-1]}{dataset_name.split(/)[-1]},model_name_or_path: model_name_or_path,dataset_name: dataset_name,trust_remote_code: True,dataset_train_split: train[:500],dataset_test_split: validation[500:600],use_peft: True,report_to: none,lora_target_modules: [q_proj, k_proj, v_proj]}trainer_kwargs {}sft_pipeline(data_processor, config_kwargs, trainer_kwargs)def ppo_pipeline_test():logging.info(PPO unittest ...)model_name_or_path os.path.join(model_home, model_names[1])EleutherAI/pythia-1b-dedupedGPTNeoXForCausalLM((gpt_neox): GPTNeoXModel((embed_in): Embedding(50304, 2048)(emb_dropout): Dropout(p0.0, inplaceFalse)(layers): ModuleList((0-15): 16 x GPTNeoXLayer((input_layernorm): LayerNorm((2048,), eps1e-05, elementwise_affineTrue)(post_attention_layernorm): LayerNorm((2048,), eps1e-05, elementwise_affineTrue)(post_attention_dropout): Dropout(p0.0, inplaceFalse)(post_mlp_dropout): Dropout(p0.0, inplaceFalse)(attention): GPTNeoXAttention((query_key_value): Linear(in_features2048, out_features6144, biasTrue)(dense): Linear(in_features2048, out_features2048, biasTrue))(mlp): GPTNeoXMLP((dense_h_to_4h): Linear(in_features2048, out_features8192, biasTrue)(dense_4h_to_h): Linear(in_features8192, out_features2048, biasTrue)(act): GELUActivation())))(final_layer_norm): LayerNorm((2048,), eps1e-05, elementwise_affineTrue)(rotary_emb): GPTNeoXRotaryEmbedding())(embed_out): Linear(in_features2048, out_features50304, biasFalse))dataset_name os.path.join(dataset_home, dataset_names[0])reward_model_path os.path.join(model_home, model_names[2])data_processor Noneconfig_kwargs {output_dir: f./temp/ppo{model_name_or_path.split(/)[-1]}{dataset_name.split(/)[-1]},model_name_or_path: model_name_or_path,dataset_name: dataset_name,reward_model_path: reward_model_path,trust_remote_code: True,dataset_train_split: train[:500],dataset_test_split: validation[:100],use_peft: True,report_to: none,lora_target_modules: [query_key_value],}trainer_kwargs {}ppo_pipeline(data_processor, config_kwargs, trainer_kwargs)def dpo_pipeline_test():logging.info(DPO unittest ...)model_name_or_path os.path.join(model_home, model_names[0])dataset_name os.path.join(dataset_home, dataset_names[2])data_processor Noneconfig_kwargs {output_dir: f./temp/dpo{model_name_or_path.split(/)[-1]}{dataset_name.split(/)[-1]},model_name_or_path: model_name_or_path,dataset_name: dataset_name,trust_remote_code: True,dataset_train_split: descriptiveness[:500],dataset_test_split: descriptiveness[500:600],use_peft: True,report_to: none,lora_target_modules: [q_proj, k_proj, v_proj]}trainer_kwargs {}dpo_pipeline(data_processor, config_kwargs, trainer_kwargs)def grpo_pipeline_test():logging.info(GRPO unittest ...)model_name_or_path os.path.join(model_home, model_names[0])dataset_name os.path.join(dataset_home, dataset_names[0])data_processor Nonedef reward_funcs(completions, **kwargs):return [float(len(set(completion))) for completion in completions]config_kwargs {output_dir: f./temp/grpo{model_name_or_path.split(/)[-1]}{dataset_name.split(/)[-1]},model_name_or_path: model_name_or_path,dataset_name: dataset_name,trust_remote_code: True,dataset_train_split: train[:500],dataset_test_split: validation[:100],use_peft: True,report_to: none,lora_target_modules: [q_proj, k_proj, v_proj]}trainer_kwargs {reward_funcs: reward_funcs,}grpo_pipeline(data_processor, config_kwargs, trainer_kwargs)这里主要就是grpo_pipeline_test中的reward_funcs其他几个测试函数都没什么好说的很好理解 GRPOTrainer中有required参数reward_funcs这个在TRL文档中给出的例子是 def reward_funcs(completions, **kwargs):return [float(len(set(completion))) for completion in completions]这个奖励的意思是会给长回答更高的得分。 注意这里的**kwargs是不可以省略的因为GRPOTrainer调用reward_funcs会传入其他参数比如prompts或者ground_truths在TRL的GRPO文档中有很多例子。 2 一些其他的问题 这部分摘取一些和DeepSeek讨论的问题 2.1 DataProcessor与DataCollator 关于TRL的Trainer的一个问题因为Trainer是可以传入data_collator参数的但是一般来说我们都会在向Trainer传入数据集如train_dataset之前先写好一个函数用于预处理数据集比如train_dataset train.dataset.map(prompt_formatter)这样我的理解其实这里prompt_formatter何data_collator的作用是一样的我觉得它们处理数据的逻辑也是一样就是传入一个样本然后做数据处理因此如果在Trainer外面已经做了train_dataset.map(prompt_formatter)的操作就不需要再传入data_collator了。反之也是一样我可以传入data_collator而无需进行train_dataset.map(prompt_formatter)的操作 以下回答由AI生成 你的理解基本正确但有一些细节需要注意。prompt_formatter在dataset.map中使用和data_collator确实都涉及数据处理但它们的作用时机和目的有所不同。下面我来详细解释它们的区别以及何时使用它们 2.1.1 dataset.map(prompt_formatter): 作用时机: 在数据加载时预处理阶段即在训练开始之前对数据集中的每个样本进行处理。典型用途: 将原始文本转换为模型需要的输入格式例如添加特殊token、构建prompt-completion结构。对数据进行过滤通过remove_columns删除不需要的字段。执行一次性的、确定性的转换例如tokenization如果你不启用动态padding。 特点: 处理结果是静态的处理后的数据会被缓存除非设置load_from_cache_fileFalse。适合所有样本共享的逻辑例如统一的prompt模板。 2.1.2 data_collator: 作用时机: 在训练过程中动态批处理阶段即在每个batch生成时对样本进行实时处理。典型用途: 动态padding将同一batch内的样本padding到相同长度。随机数据增强例如随机masking适用于MLM任务。构造特殊输入例如为RLHF构建奖励模型的输入格式。 特点: 处理是动态的每次batch采样时重新执行。适合需要随机性或依赖batch内其他样本的操作例如padding长度取决于当前batch。 2.1.3 何时选择哪种方式 只用dataset.map: 如果你的数据处理是确定性的、不需要动态调整例如固定的prompt模板且不需要batch级别的操作如padding可以只使用dataset.map。但注意如果涉及变长序列仍需data_collator处理padding除非你已手动填充到固定长度。 只用data_collator: 如果数据处理需要动态性例如随机masking或依赖batch上下文如padding可以只用data_collator。但注意这样可能导致训练时重复处理数据例如每次epoch重复tokenization影响效率。 联合使用: 常见做法是 用dataset.map处理静态逻辑例如tokenization、添加固定prompt。用data_collator处理动态逻辑例如padding、随机增强。 例如def prompt_formatter(example):example[input_ids] tokenizer(example[text]).input_idsreturn example # 不padding保留原始长度train_dataset raw_dataset.map(prompt_formatter, batchedFalse)data_collator DataCollatorForLanguageModeling(tokenizer, mlmTrue) # 动态padding和masking如果prompt_formatter已经完成了所有必要处理包括padding到固定长度且不需要任何动态逻辑确实可以跳过data_collator或传入default_data_collator它仅做简单的batch堆叠。 但如果需要动态操作如padding即使有dataset.map仍需data_collator。 二者并非互斥而是互补的。dataset.map用于离线预处理data_collator用于在线批处理。 选择取决于你的数据处理需求是否依赖动态上下文或随机性。 2.2 关于PartialState accelerate库里有一个方法PartialState().local_main_process_first()我看到这个方法会用来加载数据集这个看起来像是并行加载数据集的意思你可以解释一下这个函数是用来做什么的吗 注这个方法本来是在PPO官方示例中使用的 PartialState().local_main_process_first() 是 Hugging Face accelerate 库中用于多进程并行处理的一个关键方法尤其在分布式训练或推理时控制数据加载、文件处理等操作的执行顺序。它的核心作用是确保某些操作如数据集加载或文件预处理只在**主进程main process**上优先执行避免多进程竞争或重复操作。 2.2.1 核心功能 local_main_process_first() 返回一个上下文管理器context manager其行为如下 主进程local main process会立即执行上下文内的代码块。其他进程非主进程会等待主进程完成后再跳过或执行相同操作取决于具体逻辑。 目的避免多进程同时操作共享资源如文件读取、数据集缓存导致的冲突或冗余计算。 2.2.2 典型使用场景 (1) 加载数据集时避免重复缓存 在分布式训练中如果多个进程同时尝试缓存数据集例如通过 datasets 库可能会导致文件写入冲突。通过 local_main_process_first() 可以确保只有主进程生成缓存其他进程直接复用 from accelerate import PartialStatestate PartialState() # 初始化分布式状态# 只有主进程先缓存数据集其他进程等待后直接加载缓存 with state.local_main_process_first():dataset load_dataset(imdb) # 假设这是耗时操作(2) 文件预处理或下载 如果预处理如分词、数据转换只需要执行一次 with state.local_main_process_first():tokenized_data dataset.map(tokenize_function, batchedTrue) # 主进程先处理 # 其他进程会自动等待主进程完成然后直接使用处理后的数据2.2.3 参数与底层机制 PartialState是 accelerate 中管理分布式状态的类自动检测当前进程的分布式环境如多GPU、TPU等。local_main_process指当前节点机器上的主进程通常是全局主进程的一个子集在多机训练时每台机器有自己的 local_main_process。同步机制底层通过 torch.distributed 或类似库实现进程间同步确保非主进程等待主进程完成。 2.2.4 与类似方法的区别 方法/属性作用范围用途local_main_process_first()当前节点的主进程优先避免文件/缓存操作的竞争main_process_first()全局主进程优先跨所有节点适用于全局只需执行一次的操作on_main_process()仅在主进程执行完全跳过其他进程如日志记录 2.2.5 完整示例分布式训练中的数据加载 from accelerate import PartialState from datasets import load_dataset# 初始化分布式状态自动检测环境 state PartialState()# 只有主进程优先加载并缓存数据其他进程等待 with state.local_main_process_first():dataset load_dataset(glue, mrpc, splittrain)tokenized_data dataset.map(lambda x: tokenizer(x[sentence]), batchedTrue)# 所有进程继续同步训练 for batch in tokenized_data:outputs model(**batch)...2.2.6 注意事项 缓存路径一致性确保所有进程能访问相同的缓存路径如共享存储。性能影响主进程的预处理时间会被计入总耗时但避免了重复操作的开销。错误处理如果主进程预处理失败其他进程会一直等待需添加超时或错误处理逻辑。 2.2.7 总结 何时使用在分布式环境中任何需要单例执行的操作如数据缓存、文件下载、预处理都应使用 local_main_process_first()。优势减少资源竞争避免冗余计算确保分布式任务的高效协同。 如果需要进一步了解 accelerate 的分布式策略如多机训练可以继续探讨 2.3 DPO和PPO的reference_model的区别 这个是阅读TRL文档发现的一个问题DPOTrainer的ref_model参数是有默认值None的但是PPOTrainer的ref_model参数没有默认值。这意味着必须给PPOTrainer传入ref_model参数尽管一般都是跟目标模型一样。 你提出了一个很好的问题DPODirect Preference Optimization虽然确实不属于传统强化学习的范畴但它仍然保留了reference model这是出于特定的设计考虑。下面我详细解释原因 2.3.1 DPO的核心思想与Reference Model的作用 DPO通过直接优化偏好数据来规避强化学习但其目标函数源自Bradley-Terry模型隐式地包含了一个KL散度约束 L DPO − E [ log ⁡ σ ( β log ⁡ π θ ( y w ∣ x ) π ref ( y w ∣ x ) − β log ⁡ π θ ( y l ∣ x ) π ref ( y l ∣ x ) ) ] \mathcal{L}_{\text{DPO}} -\mathbb{E} \left[ \log \sigma\left( \beta \log \frac{\pi_\theta(y_w|x)}{\pi_{\text{ref}}(y_w|x)} - \beta \log \frac{\pi_\theta(y_l|x)}{\pi_{\text{ref}}(y_l|x)} \right) \right] LDPO​−E[logσ(βlogπref​(yw​∣x)πθ​(yw​∣x)​−βlogπref​(yl​∣x)πθ​(yl​∣x)​)] 其中 π θ \pi_\theta πθ​ 是待训练的策略当前模型。 π ref \pi_{\text{ref}} πref​ 是reference model通常初始化为预训练模型例如SFT模型。 β \beta β 是控制偏离参考模型程度的超参数。 Reference Model的作用 防止过度偏离初始策略KL散度约束通过reference model体现避免模型为迎合偏好数据而退化例如生成极端或无意义的输出。隐式正则化类似于PPO中的KL惩罚但DPO通过直接在损失函数中对比当前模型和reference模型的输出概率实现。 2.3.2 为什么DPO需要Reference Model尽管它不是RL 数学推导的必然性DPO的公式是从逆向强化学习的目标推导而来的其中reference model代表初始策略即RL中的 π old \pi_{\text{old}} πold​。去掉reference model会导致目标函数失去约束模型可能过拟合偏好数据。实践中的稳定性实验表明去掉reference model后模型容易过度优化偏好信号例如始终输出短答案或重复模式。 2.3.3 TRL的DPOTrainer中的ref_model 在HuggingFace TRL库中DPOTrainer确实需要传入ref_model参数原因如下 计算概率比DPO需要比较当前模型和reference模型对正/负样本的概率即(\log \frac{\pi_\theta(y|x)}{\pi_{\text{ref}}(y|x)}))。冻结参考模型ref_model在训练过程中是冻结的不更新仅作为基准。内存优化实际实现中ref_model可能与当前模型共享部分结构例如共享底层Transformer但通过梯度截断避免更新。 2.3.4 DPO vs PPO的Reference Model 特性PPO (RL)DPO (Non-RL)Reference Model旧策略 π old \pi_{\text{old}} πold​预训练模型 π ref \pi_{\text{ref}} πref​更新方式定期同步on-policy要求完全冻结作用重要性采样校正 KL约束隐式KL约束 防止过拟合 2.3.5 如果没有Reference Model会怎样 假设令 π ref \pi_{\text{ref}} πref​为均匀分布则DPO退化为直接最大化偏好对的似然可能导致 模型过度自信overconfident。生成内容多样性下降模式坍塌。对噪声偏好数据过拟合。 DPO虽然跳过了显式的强化学习框架但通过保留reference model继承了RL中“约束策略更新”的核心思想。这是其稳定性和性能的关键设计也是它区别于纯监督学习如直接微调的重要特征。 PS我感觉这个回答没有解释我的问题答非所问了
http://www.dnsts.com.cn/news/64381.html

相关文章:

  • 免费的seo教程专注于上海seo做网站建设
  • 苏州建站公司哪个好帮建网站
  • 怎么利用爬虫技术 提高网站排名怎样网站建设与管理
  • 做化工外贸需要那些网站wordpress mkv格式
  • 购物网站优化的建议wordpress调用图片路径
  • 做网站怎么做付费推广上海建设摩托车官网
  • 有口碑的常州网站优化苏州园区已经烂掉了
  • 移动网站mip5000做网站
  • 怎样进入当地建设局网站邢台市建设局网站
  • 网站建设需要哪些费用支出简单大气网站欣赏
  • 网站备案后网站建设中最基本的决策之一是
  • 如何拷贝服务器里面网站做备份吴江做企业网站
  • 如何建一个论坛网站高端品牌羽绒服前十名
  • 平面设计免费素材网站苏州代理记账
  • 佛山做网站建设公司wordpress关闭手机访问不了
  • 自己搭建环境建设网站wordpress网站和微信公众号
  • 快速建站服务器山西晋中网站建设
  • 做海报素材网站推荐网站制作教程
  • 需要网站建设的是哪一类人网站建设基本
  • 宁波网站制作公司哪家好网页版qq中心登录入口
  • 网站效果案例通许画册设计网站
  • 玉溪企业网站建设php网站模板制作软件
  • 2016做网站还赚钱吗网络营销方法有哪些举例
  • 建个营销型网站多少钱河北手动网站建设商店
  • 关于进一步加强门户网站建设粒子特效网站
  • 老干部活动中心网站建设方案wordpress 模板 html5
  • 无锡网站搜索优化品牌推广的步骤和技巧
  • 建设网站赚钱的方法相城区公司网站建设
  • 免费企业cms建站系统中国建设银行手机银行官网
  • 商丘柘城做网站php网站是什么数据库文件