Releases: huggingface/trl
v0.19.1
What's Changed
- fix grpo generation_kwargs by @ahatamiz in #3634
- Make sure chat template isn't lost when truncating prompt. by @pramodith in #3651
- Add parenthesis to correct the check. by @pramodith in #3658
- [SFT] drop attention_mask if we have position ids for fa2 by @kashif in #3673
- Support datasets 4 by @lhoestq in #3688
- 📣 Use explicit version for checking datasets version by @qgallouedec in #3702
- Fix non-serializable torch.dtype bug in VLLM weight sync by @CarlosArguilar in #3690
- ✂️ [BUG when vllm and prompt_truncation are used]: Strip out pad tokens in truncated prompt text by @pramodith in #3698
New Contributors
- @ahatamiz made their first contribution in #3634
- @lhoestq made their first contribution in #3688
- @CarlosArguilar made their first contribution in #3690
Full Changelog: v0.19.0...v0.19.1
v0.19.0
Breaking and major changes
🧰 [SFT] Tool support
SFTTrainer now supports training with tools! Just add a tools column to your dataset containing a list of tool definitions as JSON schemas. The tools will be automatically registered and can be used during training.
from datasets import Dataset
from transformers.utils import get_json_schema
from trl import SFTTrainer

# Fictitious functions to simulate tool calls
def start_timer(duration: int) -> int:
    """
    Starts a timer for the specified duration in seconds.

    Args:
        duration: Duration in seconds to set the timer for.

    Returns:
        The duration set for the timer.
    """
    return duration

def create_reminder(time: str, note: str) -> str:
    """
    Creates a reminder for the specified time and note.

    Args:
        time: The time for the reminder.
        note: The note for the reminder.

    Returns:
        A confirmation message indicating that the reminder has been set.
    """
    return "I'll remind you to call mom at 7 PM."

# Define the JSON schemas for the tools
start_timer = get_json_schema(start_timer)
create_reminder = get_json_schema(create_reminder)

dataset = Dataset.from_dict({
    "messages": [
        [
            {"role": "user", "content": "Set a timer for 10 minutes."},
            {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": "start_timer", "arguments": {"duration": 600}}}]},
            {"role": "tool", "name": "start_timer", "content": "600"},
            {"role": "assistant", "content": "Timer set for 10 minutes."},
        ],
        ...,
    ],
    "tools": [
        [start_timer, create_reminder],
        ...,
    ]
})

# Initialize the trainer
trainer = SFTTrainer(model="Qwen/Qwen3-0.6B", train_dataset=dataset)

# Train the model
trainer.train()
by @qgallouedec in #3597
📉 FFD packing
We introduce a new packing method: FFD (First Fit Decreasing) packing. This method packs sequences more efficiently, reducing the size of the packed training dataset by grouping examples more effectively. Previously, we used a wrapped packing method, which often truncated sequences even when they were not longer than the maximum sequence length. FFD packing avoids this unnecessary truncation by grouping sequences more intelligently, and it is now the default strategy when packing is enabled.
training_args = SFTConfig(..., packing=True)
by @qgallouedec in #3521 and accelerated by @mariosasko in #3537
[Liger] liger DPO support
The DPOTrainer now supports the Liger-powered DPO loss, enabling faster training with lower memory usage.
training_args = DPOConfig(..., use_liger_loss=True)
💬 Fix setup_chat_format and add clone_chat_template
We introduce clone_chat_template, a more convenient and flexible function for setting up chat templates from any tokenizer that already includes one. It handles EOS tokens and copies all added tokens from the source tokenizer, preserving their "special" status.
You can either use this function directly:
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import clone_chat_template
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-4B")
or use the chat_template_path parameter in SFTConfig to specify a chat template, which will be automatically cloned when the SFTTrainer is initialized.
from trl import SFTConfig
training_args = SFTConfig(chat_template_path="Qwen/Qwen3-4B")
by @qgallouedec in #3404 and #3599
📚 SFTTrainer supports chat template kwargs
SFTTrainer now supports passing additional keyword arguments to the chat template, allowing more flexibility in customizing the chat format during training. To enable it, just add a chat_template_kwargs column to your dataset.
example = {'messages': [{'content': 'What is better than ugly?', 'role': 'user'},
                        {'content': 'Beautiful.', 'role': 'assistant'}],
           'chat_template_kwargs': {'my_template_arg': 'my_value'}}
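Conceptually, the per-example kwargs are forwarded to the tokenizer's chat template when the example is formatted. Here is a rough sketch of the idea using Qwen3's enable_thinking switch as a concrete template argument (my_template_arg above is just a placeholder, and this snippet is illustrative rather than TRL internals):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

messages = [
    {"role": "user", "content": "What is better than ugly?"},
    {"role": "assistant", "content": "Beautiful."},
]

# Extra keyword arguments are passed through to the chat template;
# `enable_thinking` is a real Qwen3 template argument used here as an example.
text = tokenizer.apply_chat_template(messages, tokenize=False, enable_thinking=False)
print(text)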
by @qgallouedec in #3609
🤵♂️ SFT on assistant messages only
The SFTTrainer now supports training on assistant messages only:
example = {'messages': [
    {'role': 'user', 'content': 'What is better than ugly?'},          # masked in the loss
    {'role': 'assistant', 'content': 'Beautiful.'},                    # used in the loss
    {'role': 'user', 'content': 'And what is better than implicit?'},  # masked in the loss
    {'role': 'assistant', 'content': 'Explicit.'},                     # used in the loss
]}
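A minimal sketch of enabling this, assuming the SFTConfig option is named assistant_only_loss (the flag is not shown in these notes; verify against your installed TRL version):
from trl import SFTConfig

# Flag name assumed (assistant_only_loss); check the SFT docs for your TRL version.
training_args = SFTConfig(assistant_only_loss=True)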
by @qgallouedec in #3586
🧬 Add generation_kwargs as a property of GRPOConfig to support additional generation arguments
The GRPOConfig now includes a generation_kwargs property, allowing users to specify additional generation arguments for the GRPOTrainer. This allows for further customization of the generation behavior, such as setting suppress_tokens, num_beams, etc.
Depending on the generation backend used (transformers or vLLM), this property will be passed either to transformers.GenerationConfig (if using transformers) or vllm.SamplingParams (if using vLLM).
from trl import GRPOConfig
training_args = GRPOConfig(..., generation_kwargs={"length_penalty": -0.1})
by @pramodith in #3617
New defaults
- 🎀 New default: beta=0.0 for GRPO by @qgallouedec in #3516
- 🎀 New defaults: preparing the new structure by @qgallouedec in #3530
- 🎀 New defaults: logging_steps=10 by @qgallouedec in #3514
- 🎀 [SFT][Bugfix] sets average_tokens_across_devices to true in SFTConfig by @edbeeching in #3538
- 🎀 New defaults: bf16=True by @qgallouedec in #3515
Minor changes
- Add support for IterableDataset in DPO Trainer by @h-tonywu in #3559
- 🔖 Fix: ensure user-provided labels are retained in self._signature_columns by @sxndqc in #3589
- ⭐ Add vllm_gpu_memory_utilization recommendation script by @toslali-ibm in #3554
What's Changed
- ⬆️ Bump dev version by @qgallouedec in #3505
- 📎 Fix clip ratio logging by @qgallouedec in #3506
- 📚 Fix doc building by removing vLLM from dev dependencies in setup.cfg by @qgallouedec in #3511
- 🧭 Patch release guide by @qgallouedec in #3512
- 🎀 New default: beta=0.0 for GRPO by @qgallouedec in #3516
- Add "🐯 Liger GRPO meets TRL" by @qgallouedec in #3525
- 📉 FFD packing by @qgallouedec in #3521
- 🎀 New defaults: preparing the new structure by @qgallouedec in #3530
- 🪦 RIP trl chat by @shirinyamani in #3531
- 🎀 New defaults: logging_steps=10 by @qgallouedec in #3514
- 📰 Add blog "No GPU left behind: Unlocking Efficiency with Co-located vLLM in TRL" by @qgallouedec in #3527
- 🎯 Don't use getattr to get gradient_checkpointing by @qgallouedec in #3535
- 🧭 Remove useless transformers version checks by @qgallouedec in #3534
- 🐳 Add DeepseekV3 model configurations and update tests for new models by @qgallouedec in #3536
- 💭 [Data] Fix DeepSeek-R1 case by @kashif in #3522
- 🎀 [SFT][Bugfix] sets average_tokens_across_devices to true in SFTConfig by @edbeeching in #3538
- ⚡ Faster FFD packing by @mariosasko in #3537
- 📦 Packing with flash attn kwargs to avoid cross-contamination by @thepowerfuldeez in #3526
- 💽 [TRLParser] Fail when unknown args are provided in the config file. by @edbeeching in #3543
- 🛋️ Fix CI and bump accelerate by @qgallouedec in #3551
- 🧮 Rearrange DPOTrainer by @DaizeDong in #3501
- 🆙 Bump transformers to 4.51 and use _VALID_DICT_FIELDS by @qgallouedec in #3553
- Update tests_latest.yml by @qgallouedec in #3558
- ℹ️ Unify autocast behavior to torch.autocast and make it cover XPU by @yao-matrix in #3541
- Fix dev version by @Tavish9 in #3570
- [Lig...
v0.18.2
What's Changed
- 🏗️ Add test for training with multiple dataloader workers and update worker initialization for compatibility with transformers 4.52.0 by @qgallouedec in #3568
Full Changelog: v0.18.1...v0.18.2
v0.18.1
What's Changed
- 📎 Fix clip ratio logging by @qgallouedec in #3506
- 📚 Fix doc building by removing vLLM from dev dependencies in setup.cfg by @qgallouedec in #3511
Full Changelog: v0.18.0...v0.18.1
v0.18.0
Major or breaking
- PEFT support for Liger GRPO by @SalmanMohammadi in #3355
- [🐯+GRPO] Support FSDP + Fix bug when using LigerGRPO with DDP by @shivam15s in #3260
- 🤝 Compatibility of the TRL CLI with accelerate arguments by @qgallouedec in #3409
- 🧑🤝🧑 Co-Locating vLLM w/ training for higher throughput and GPU utilization by @toslali-ibm in #3394
- ✌️ Add support for FSDP2 by @lewtun in #3317
- 💔 [GRPO] Decouple gradient accumulation from the number of minibatches generated by @edbeeching in #3388
- [Models] Activation checkpointing from TorchTune by @kashif in #2954
- feat: Implement Two-Sided Clipping for GRPO Trainer by @ucalyptus in #3434
- 🎁 Reward submodule by @qgallouedec in #3430
What's Changed
- ⬆️ Bump dev version by @qgallouedec in #3357
- 🔢 Pad to multiple of by @qgallouedec in #3362
- 🥸🔢 Adding pad_multiple to SFT trainer by @shirinyamani in #3365
- 🎭 Fix train and eval mode checking in GRPOTrainer and SFTTrainer by @I-l-l-I in #3337
- Better guards for DeepSpeed imports by @lewtun in #3351
- ⚰️ Remove deprecated by @qgallouedec in #3364
- 📋 Allow calling trl cli in sft mode with config file by @CloseChoice in #3380
- PEFT support for Liger GRPO by @SalmanMohammadi in #3355
- DPO fixes for evaluations by @winglian in #3377
- Deprecate TextEnvironment and tools by @lewtun in #3389
- [🐯+GRPO] Support FSDP + Fix bug when using LigerGRPO with DDP by @shivam15s in #3260
- [GRPO] Reference model initialization bug fix by @LeonEricsson in #3397
- 🌊 Add MLflow metrics in profiling context by @dhruvmullick in #3400
- 🧑🤝🧑 Co-Locating vLLM w/ training for higher throughput and GPU utilization by @toslali-ibm in #3394
- ✨ [IterativeSFT] Small refresher by @LeonEricsson in #3378
- 💔 [SFT] Raise error when formatting_func is used with completion_only_loss by @LeonEricsson in #3385
- 🦁 Fix liger initialization by @shivam15s in #3401
- 👉 [DPO] Model forward pass padding side fix by @LeonEricsson in #3307
- 🪪 Remove license classifier by @qgallouedec in #3402
- 🕺 Migrate setup configuration from setup.py to setup.cfg and make rich an optional dep by @qgallouedec in #3403
- 🕊️ Un-restrict diffusers by @qgallouedec in #3407
- ✌️ Add support for FSDP2 by @lewtun in #3317
- 🤝 Compatibility of the TRL CLI with accelerate arguments by @qgallouedec in #3409
- 💔 [GRPO] Decouple gradient accumulation from the number of minibatches generated by @edbeeching in #3388
- 🎲 [GRPO] Shuffle mini batches by @edbeeching in #3391
- 📝 vLLM-integration documentation by @shirinyamani in #3376
- 🎁 Reward takes completion ids by @qgallouedec in #3272
- 🐍 Support Python 3.13 by @qgallouedec in #2593
- [Models] Activation checkpointing from TorchTune by @kashif in #2954
- 🧪 Testing support for Qwen3 tiny by @shirinyamani in #3415
- Update README.md by @qgallouedec in #3420
- 🏹 Support kv_cache_dtype to quantize kv-cache in vllm by @winglian in #3422
- enable trl env on xpu by @yao-matrix in #3438
- use device agnostic empty_cache in ppo & rloo by @yao-matrix in #3439
- feat: Implement Two-Sided Clipping for GRPO Trainer by @ucalyptus in #3434
- 🎁 Reward submodule by @qgallouedec in #3430
- [CI] fix CI failure of transformer dev by @kashif in #3457
- enable vllm c-s tests on XPU by @yao-matrix in #3445
- enable activation offloading on XPU by @yao-matrix in #3444
- 🙅 PPO value_model can't be None, so it shouldn't be Optional by @AMindToThink in #3300
- [NashMD] fix the edge case where the model is a peft model by @kashif in #3473
- Update .pre-commit-config.yaml by @kashif in #3479
- [SFT] update minimal liger version by @kashif in #3483
- [CI] fix sampler api to make the CI green by @kashif in #3488
- Fix typo by @nikolai-kummer in #3489
- [Doc][SFT] Update sft_trainer.md. link prompt-completion dataset example by @HERIUN in #3486
- Fix mis-aligned prompts and completions in colocate mode by @toslali-ibm in #3491
- [Docs] sync logging doc to current metrics by @kashif in #3478
- [GRPO] disabling top_k sampling default by @kashif in #3494
- [GKD] fix the gkd script by @kashif in #3497
- 👇 Update grpo.py to fix bugs for cli grpo --reward_funcs my_lib.my_reward by @wa008 in #3454
- 🛠️ Initialize reward_kwargs to prevent UnboundLocalError in GRPOTrainer by @teilomillet in #3459
- 🐌 Clean two-sided clipping by @qgallouedec in #3499
- 🔭 [GRPO] Log advantages and fraction of samples with an std of zero by @edbeeching in #3502
- 📏 Completion length logging fix + remainder logging fix by @shirinyamani in #3482
- 🤧 LD-DPO support by @AIR-hl in #3458
- 🏰 [vllm] Support base_url parameter for vLLM client initialization by @re-imagined in #3324
- ✂️ [DPO] Fix truncation keep_end leading to zero'd out samples by @LeonEricsson in #3398
- Release: v0.18 by @qgallouedec in #3504
New Contributors
- @CloseChoice made their first contribution in #3380
- @SalmanMohammadi made their first contribution in #3355
- @dhruvmullick made their first contribution in #3400
- @toslali-ibm made their first contribution in #3394
- @yao-matrix made their first contribution in #3438
- @nikolai-kummer made their first contribution in #3489
- @wa008 made their first contribution in #3454
- @teilomillet made their first contribution in #3459
- @re-imagined made their first contribution in #3324
Full Changelog: v0.17.0...v0.18.0
v0.17.0
Major and breaking
The TRL v0.17 release introduces three major changes that, together, enable significantly faster generation performance in GRPO—up to 10x faster in some configurations.
These three changes are:
- Data parallelism (DP) for the vLLM server
- A new GRPO training strategy that generates once per effective batch
- Support for the V1 engine in vLLM
Below, we provide a summary of these changes and how to use them.
⚡ Up to 4x faster: Data Parallel for vLLM server
The TRL vLLM server now supports data parallelism (DP), enabling significantly faster generation, especially for smaller models. This new feature can be used by adding the --data_parallel_size N argument when launching the vLLM server.
trl vllm-serve --model Qwen/Qwen2.5-14B-Instruct --tensor_parallel_size 2 --data_parallel_size 2
by @qgallouedec in #3310
☝️ [GRPO] Generate once per effective batch
Previously, GRPO made one generation request per global batch. The global batch is the total of all local batches, without accounting for gradient accumulation. In other words, if the gradient accumulation step was 8, GRPO would make 8 generation requests per training step.
Now, GRPO groups these global batches into a single "effective batch" and makes only one generation request per effective batch. Since vLLM applies optimizations that are especially effective for large batches, this new approach leads to significantly faster training overall.
No changes are required in the training script, as this is handled internally by the GRPO trainer.
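To make the change concrete, here is the bookkeeping with purely illustrative numbers (not taken from the release notes):
# Illustrative numbers only.
num_gpus = 8
per_device_train_batch_size = 4
gradient_accumulation_steps = 8

global_batch = num_gpus * per_device_train_batch_size          # 32 prompts
effective_batch = global_batch * gradient_accumulation_steps   # 256 prompts

# Before: 8 generation requests of 32 prompts per training step.
# Now:    1 generation request of 256 prompts per training step,
#         which lets vLLM batch far more work at once.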
by @qgallouedec in #3283
⏱️ Fix vLLM server to support V1 Engine
vLLM provides two versions of its engine (V0 and V1), and V1 is significantly faster. This version is now supported by TRL and requires vLLM version 0.8.3 or higher.
👎 [GRPO] Adds option to disable dropout
Disabling dropout has been shown to stabilize training. You can now disable dropout in GRPO by setting the disable_dropout argument to True in the GRPO config.
from trl import GRPOConfig
training_args = GRPOConfig(..., disable_dropout=True)
by @edbeeching in #3234
🩺 Dr. GRPO loss
GRPO now supports the various losses proposed in the recent literature, including the Dr. GRPO loss. The loss type can be set in the GRPO config:
from trl import GRPOConfig
training_args = GRPOConfig(..., loss_type="dr_grpo")
by @qgallouedec in #3256
🎲 [GRPO] Make training dataset shuffle optional
The GRPO trainer now has an option to disable shuffling of the training dataset. This is useful for curriculum learning, where the order of the training data is important.
from trl import GRPOConfig
training_args = GRPOConfig(..., shuffle_dataset=False)
by @LeonEricsson in #3334
☕ Overlong-filtering for GRPO
Overlong filtering has been shown to significantly stabilize learning and improve performance. You can now use it in TRL!
It simply consists of masking out the loss of truncated samples.
from trl import GRPOConfig
training_args = GRPOConfig(..., mask_truncated_completions=True)
by @shirinyamani in #3248
🐯 Integrate Liger GRPO Loss to GRPO Trainer
Liger lets you significantly reduce the peak memory of the loss computation. You can now use it in TRL with the use_liger_loss argument in the GRPO config:
from trl import GRPOConfig
training_args = GRPOConfig(..., use_liger_loss=True)
by @shivam15s in #3184
Bug fixes
- Fix: Multi gpu hang for ORPO and CPO Trainer by @NanoCode012 in #3069
- 📊 Fix clip_ratio logging and better document logged values by @qgallouedec in #3145
- ⏯️ Fix: handle None inputs when resuming GRPO Trainer from checkpoint by @PenutChen in #3148
- 📎 Fix is_clipped to compute the effective clip_ratio by @pandong2011 in #3175
- 😷 Fix SFT masking EOS when equal to PAD by @qgallouedec in #3200
- ⏯️ Fix logging when resuming from checkpoint GRPO by @qgallouedec in #3185
- 💠 Fix multi-gpu padding free by @qgallouedec in #3245
- 🕷 Fix online DPO crash when model is a DataParallel object by @wilrop in #3225
- 🏁 Fix adding special tokens in SFT by @qgallouedec in #3328
- 🍡 Fix using reward model and DeepSpeed ZeRO 3 by @qgallouedec in #3326
What's Changed
- Fix: Multi gpu hang for ORPO and CPO Trainer by @NanoCode012 in #3069
- 📊 Fix clip_ratio logging and better document logged values by @qgallouedec in #3145
- BCOTrainer version upgrade fixes by @claralp in #2867
- 🐇 [Research] Layer Skip SFT by @ariG23498 in #3111
- 🤝 Align GRPO equation doc with the implementation by @qgallouedec in #3151
- Enable number of printed completions to be set by @lewtun in #3149
- 🩹 Fix CI by @qgallouedec in #3155
- ⚰️ Remove deprecated by @qgallouedec in #3153
- 🔫 Disable triggering CI when PR is draft by @qgallouedec in #3154
- 👨‍🍳 vLLM serve: destroy process group on exit and pass worker_cls as string by @qgallouedec in #3159
- 💰 Richer rich table - log all the rewards by @qgallouedec in #3156
- 💎 Gemma 3 VLM SFT example script for single-image and multi-image by @sergiopaniego in #3131
- [Liger] Liger KTO support by @vaibhavjindal in #2812
- 🏃 Migrate CI to self-hosted runners by @qgallouedec in #3174
- ❤️🩹 [CI] fix transformers dev CI failure by @kashif in #3176
- ⏯️ Fix: handle None inputs when resuming GRPO Trainer from checkpoint by @PenutChen in #3148
- 📎 Fix is_clipped to compute the effective clip_ratio by @pandong2011 in #3175
- Fix breaking typo for flash_attention reducing_memory_usage.md by @burtenshaw in #3190
- Show unique prompts in GRPO WandB tables by @lewtun in #3191
- 🐗 [CI] Fix trufflehog false positives by @lewtun in #3192
- [GRPO] Improve completion length logging by @edbeeching in #3188
- 😷 Fix SFT masking EOS when equal to PAD by @qgallouedec in #3200
- 🗝️ Fix type hint in vLLM client by @qgallouedec in #3205
- 📚 Accumulate completions for logging by @lewtun in #3217
- Group completion metrics by common prefix by @lewtun in #3212
- 🐯 Integrate Liger GRPO Loss to GRPO Trainer by @shivam15s in #3184
- Update ruff to 11.3 and base Python version to 3.9 by @cyyever in #3230
- ⏯️ Fix logging when resuming from checkpoint GRPO by @qgallouedec in #3185
- 📢 Improve GRPO trainer error message for invalid num_generations by @AliBakly in #3199
- 🎀 Simplify logging text by @qgallouedec in #3219
- 🌊 Add error for iterable datasets in GRPOTrainer by @qgallouedec in #3216
- ⏳ PPOTrainer: fix progress bar for num_mini_batches > 1 by @dawidm in #2531
- ☑ Update PULL_REQUEST_TEMPLATE.md by @qgallouedec in #3241
- 🔭 Add support for better KL estimator (k3) in PPOTrainer by @AMindToThink in #3240
- 🏃 Fix and make CI faster by @qgallouedec in #3160
- 🗑️ Deprecate ConstantLengthDataset by @qgallouedec in #3242
- 📦 [SFT] Deprecate batched formatting_func by @YeFD in #3147
- 💠 Fix multi-gpu padding free by @qgallouedec in #3245
- ☕ Overlong-filtering for GRPO by @shirinyamani in #3248
- 📜 Fix license and copyrights by @qgallouedec in #3264
- ⛏️ Add cli dict parsing for grpo_config by @Tavish9 in #3082
- 🐯 is_liger_kernel_available with min version by @qgal...
v0.16.1
What's Changed
- 😷 Fix SFT masking EOS when equal to PAD by @qgallouedec in #3200
- 📉 Add learning_rate argument to _maybe_log_save_evaluate by @qgallouedec in #3206
Full Changelog: v0.16.0...v0.16.1
v0.16.0
Major and breaking
🚀 Scaling GRPO to 70B+ Models and Multi-Node Training with vLLM Server & NCCL Communication
Previously, vLLM could only be used by dedicating a single GPU, preventing both the scalability benefits of vLLM and multi-node training. This limitation has now been removed!
GRPO can now scale efficiently with models exceeding 70B parameters, supporting multi-node training with super-fast performance.
To take advantage of this, simply launch a vLLM server using the following command:
trl vllm-serve --model <model_name> --tensor_parallel_size <tp_size>
Then, start GRPO training with use_vllm=True.
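On the training side, nothing beyond the usual GRPO setup is needed; a minimal sketch:
from trl import GRPOConfig

# Enable generation through the external vLLM server launched above.
training_args = GRPOConfig(..., use_vllm=True)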
Below is a comparison of GRPO throughput with and without vLLM, across different TP values and model sizes.
by @binary-husky and @qgallouedec in #3094
🐦🔥 6x faster GRPO with multi-step optimization
This release introduces the multi-step trick, which allows for the reuse of generated data across multiple steps, speeding up the training process.
To support this, we've implemented importance sampling and clipping logic. This enhancement should lead to significant improvements in training speed.

To use it, simply set num_iterations to a value greater than 1.
training_args = GRPOConfig(..., num_iterations=4)
by @qgallouedec in #2899
🌍 Use global normalization in GRPO
As demonstrated in Dr. GRPO, sequence-level normalization can introduce a response-level length bias.
To address this, we have now switched to normalizing the loss by the total number of tokens in the batch, ensuring more consistent and unbiased training.
- loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
by @edbeeching in #2881
⚖️ Add option not to scale rewards
As demonstrated in Dr GRPO, scaling rewards can introduce a question-level difficulty bias. To address this, we have now added an option to disable reward scaling in GRPO.
training_args = GRPOConfig(..., scale_rewards=False)
advantages = rewards - mean_grouped_rewards
- advantages = advantages / std_grouped_rewards
+ if self.args.scale_rewards:
+     advantages = advantages / std_grouped_rewards
It's likely that we'll make this (scale_rewards=False) the default behavior in the future.
by @qgallouedec in #3135
🤸♀️ Domain-specific rewards in GRPO
When optimizing across multiple domains, not all reward functions are relevant for every sample. For example, a math verifier's reward does not apply to grammar samples, and a grammar verifier's reward does not apply to math samples.
It is now possible to return None for rewards that do not make sense for a given sample. For instance, when the domain is specified in a column like domain, you can implement it as follows:
def math_reward(completions, domain, **kwargs):
    rewards = []
    for completion, dom in zip(completions, domain):
        if dom == "math":
            rewards.append(verify(completion))
        else:
            rewards.append(None)
    return rewards
This allows for more domain-specific reward handling, ensuring that irrelevant rewards are ignored and don’t interfere with optimization.
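A sketch of how such a reward could be plugged in. Extra dataset columns (here, domain) are passed to the reward functions as keyword arguments; grammar_reward and dataset are hypothetical placeholders, and the model name is only an example:
from trl import GRPOTrainer

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=[math_reward, grammar_reward],  # grammar_reward is hypothetical
    train_dataset=dataset,  # assumed to contain "prompt" and "domain" columns
)
trainer.train()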
by @shirinyamani in #3079
🍃 Do not load reference model when beta == 0.0
It has been observed that not minimizing the KL divergence between the trained model and the reference model can still yield good results, while significantly reducing memory usage and compute. This is because there is no need to store the reference model in memory or perform a forward pass for it.
When beta is set to 0.0, the reference model is not loaded and the KL divergence is not computed, leading to savings in both time and memory.
training_args = GRPOConfig(..., beta=0.0)
🕊️ Padding-free for SFT
Padding-free batching is an alternative approach to packing for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.
To enable padding-free batching in SFT, simply set padding_free=True in the SFTConfig, and make sure to use flash_attention_2 as the attention implementation.
training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
by @qgallouedec in #3076
🎬 Clip Higher for Better Exploration
As outlined in the DAPO paper, increasing the upper bound epsilon leads to higher entropy during generation, promoting better exploration. To enable this, we’ve added support for adjusting the upper bound epsilon directly in the default GRPO trainer.
training_args = GRPOConfig(epsilon_high=0.28)
by @shirinyamani in #3118
Bug fixes
- 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading by @XZ-X in #2873
- 🪂 Don't gather logits in SFT to avoid hanging by @qgallouedec in #2890
- ♻️ Fix caching in SFT by @qgallouedec in #2945
- 🐯 Fix LigerKernel for SFTTrainer by @lewtun @kashif and @qgallouedec in #2874, #2940 and #2949
- 🫔 [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility by @kiddj in #2871
- 🛣️ inference_mode to no_grad when computing old_per_token_logps by @qgallouedec in #2987
- 🏊 [SFT] Compatibility with padding free and iterable dataset by @qgallouedec in #3053
- Fixing JSD loss computation in GKDTrainer as per definition by @abhigoyal1997 in #3043
Minor
- 💬 Add maybe_convert_to_chatml map for conversational datasets in SFT by @kashif in #2862
- 🍟 [SFT] Handles the dataset if it has been preprocessed by @BenasdTW and @DanFosing in #2863 and #2939
- ✨ Add vLLM guided decoding support to GRPO Trainer by @kldzj in #2811
- 🩳 max_seq_length to max_length by @qgallouedec in #2895 and #2947
- Optimize vllm num_generations by @edbeeching in #2855
- 📍 [GRPO] add gradient_checkpointing by @kashif in #2848
- 🪪 Adds profiling decorators for GRPOTrainer by @edbeeching in #2889 and #2975
- 🐈 Bye bye chat by @qgallouedec in #2934
- 📇 GRPO: print completions to console and update docs by @nopepper in #2951
- 👧🏽 Adding DoRA support to model config by @nbasyl in #2974
- 🧗 Add GRPO Trainer support for third-party accelerators by @ji-huazhong in #2836
- 🪙 [SFT] Log num_tokens and some logging fixes by @qgallouedec in #3006
- 🌡️ Fix temperature inconsistency in GRPO trainer by @Aladoro in #3029
- ⛔ Add EOS token to processed input in SFT by @qgallouedec in #3091
- ⚡ Pack 300 times faster, truncate 100 times faster by @mariosasko in #3009
What's Changed
- [SFT] fix check for AutoLigerKernelForCausalLM by @kashif in #2874
- 🆙 Bump vLLM min version to 0.7.2 by @edbeeching in #2860
- [GRPO] Fix loss normalization by @edbeeching in #2881
- 💬 Add maybe_convert_to_chatml map for conversational datasets in SFT by @kashif in #2862
- 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading by @XZ-X in #2873
- 🍟 [SFT] Handles the dataset if it has been preprocessed by @BenasdTW in #2863
- Optimize vllm num_generations ...
v0.15.2
What changed
- ♻️ Fix caching in SFT by @qgallouedec in #2945
- 🐯 Fix LigerKernel for SFTTrainer by @lewtun in #2940
- 📌 Pin liger-kernel and vLLM by @qgallouedec in #2952
Full Changelog: v0.15.1...v0.15.2
v0.15.1
What's Changed
- 💬 Add maybe_convert_to_chatml map for conversational datasets in SFT by @kashif in #2862
- [SFT] fix check for AutoLigerKernelForCausalLM by @kashif in #2874
- 🍟 [SFT] Handles the dataset if it has been preprocessed by @BenasdTW in #2863
- 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading by @XZ-X in #2873
- 🪂 Don't gather logits in SFT to avoid hanging by @qgallouedec in #2890
- Release: v0.15.1 by @qgallouedec
Full Changelog: v0.15.0...v0.15.1