gspo: GSPO loss, docs_per_step accumulation, normalize_by_documents, temperature scaling#502
Open
bigximik wants to merge 4 commits into grpo-metrics from
Conversation
Implements GSPO (geometric-mean sequence-level policy-gradient loss) as an alternative to the existing per-token GRPO clipping. Controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo". Key changes:
- data pipeline: expose per-token document_index when return_document_index=True
- LanguageModelKwargs.document_index: new kwarg constant
- LanguageModelLoss: store SDP dim for cross-rank segment aggregation
- grpo.py: fused_gspo_loss_forward_backward with all_reduce(SUM) across SDP ranks before computing segment-level R_s and A_s; the gradient derivation exploits tok_count cancellation so every token in a segment gets the same gradient factor R_s * clip_indicator_s
- tests/layers/test_gspo_loss.py: 8 unit tests (single-segment, packed, ratio-1 equivalence, clipping, masking, SDP mock, gradient finite-diff, per-token metrics unchanged)
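For reference, the segment-level ratio and clipping described above can be sketched as a plain autograd implementation. This is a hypothetical standalone helper, not the fused kernel from this PR; the function name and signature are illustrative:

```python
import torch

def gspo_reference_loss(new_logp, old_logp, advantages, segment_ids,
                        eps_low=0.2, eps_high=0.2):
    """Sketch of a GSPO-style sequence-level clipped loss.

    new_logp, old_logp: (num_tokens,) per-token log-probabilities
    advantages:         (num_segments,) one advantage A_s per document
    segment_ids:        (num_tokens,) int64 document index per token
    """
    num_segments = int(segment_ids.max()) + 1
    log_ratio = new_logp - old_logp
    # Per-segment sums of log-ratios and token counts (the quantities the
    # fused kernel all-reduces across SDP ranks before clipping).
    seg_sum = torch.zeros(num_segments).index_add_(0, segment_ids, log_ratio)
    tok_count = torch.zeros(num_segments).index_add_(
        0, segment_ids, torch.ones_like(log_ratio))
    # Geometric-mean importance ratio R_s = exp(mean per-token log-ratio).
    ratio = torch.exp(seg_sum / tok_count)
    clipped = ratio.clamp(1.0 - eps_low, 1.0 + eps_high)
    # PPO-style clipped objective, applied once per segment.
    return -torch.minimum(ratio * advantages, clipped * advantages).mean()
```

Because every token in a segment contributes equally to the geometric mean, the gradient of this loss is a uniform per-token factor within each segment, which is what the tok_count cancellation in the fused kernel exploits.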
Adds ScheduleConfig.rollouts_per_step (default 0). When >0, TrainerConfig._from_dict
computes depth_first_micro_batches = rollouts_per_step // (batch_data_parallel ×
breadth_first_micro_batches) before sub-configs are created (and frozen).
Matches DeepSpeed gradient_accumulation_passes semantics for RL: with train_batch_size=1
each microbatch holds one rollout, so setting rollouts_per_step=1024 with data_parallel=8
gives depth_first_micro_batches=128 → exactly 1024 rollouts per optimizer step globally.
YAML usage:

    schedule:
      rollouts_per_step: 1024   # replaces manual depth_first_micro_batches
    model:
      distributed:
        data_parallel: 8        # used for the division
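The division performed in _from_dict can be sketched as a standalone helper (hypothetical function; the PR does this inline during config construction):

```python
def rollouts_to_depth_first(rollouts_per_step: int,
                            batch_data_parallel: int,
                            breadth_first_micro_batches: int = 1) -> int:
    """Accumulation depth needed to consume rollouts_per_step rollouts per
    optimizer step, when each microbatch holds one rollout
    (train_batch_size=1)."""
    per_pass = batch_data_parallel * breadth_first_micro_batches
    if rollouts_per_step % per_pass:
        raise ValueError("rollouts_per_step must be divisible by "
                         "data_parallel * breadth_first_micro_batches")
    return rollouts_per_step // per_pass
```

With the values from the YAML above, 1024 rollouts across 8 data-parallel ranks gives a depth of 128 accumulation passes.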
- Rename rollouts_per_step → docs_per_step in ScheduleConfig; depth_first
is now determined at runtime rather than statically in _from_dict
- Add Schedule._depth_first_override and _eff_{depth_first,sequential,num_inputs}
properties so per-step schedules share the same config object as the runner
- Add Trainer._prefetch_to_doc_target: fetches microbatches one at a time,
all-reduces doc count per microbatch, stops when global total ≥ docs_per_step,
then resets num_documents_in_batch to the step total on all inputs
- Add Trainer._get_or_build_schedule: builds and caches per-N Schedule with
_depth_first_override=N//breadth_first_micro_batches
- Add normalize_by_documents flag to LanguageModelGRPOLossConfig; when True
both GRPO and GSPO paths divide by num_documents_in_batch instead of
num_labels_in_batch (matches DeepSpeed's per-rollout normalization)
- Add tests/layers/test_docs_per_step.py: 13 unit tests covering divisor
scaling, normalize_by_documents layer routing, Schedule._eff_* properties,
and _prefetch_to_doc_target accumulation logic
Add temperature field to LanguageModelGRPOLossConfig. When set to match the actor's sampling temperature (e.g. 0.7), new log-probs are computed at the same temperature as the stored old log-probs, so the IS ratio starts near 1.0 instead of ~1.08. Implementation: _effective_logits_scale = logits_scale_factor / temperature, substituted for logits_scale_factor at all three callsites in _forward_backward (GRPO path, GSPO path, _register_pg_metrics). Default temperature=1.0 preserves existing behaviour exactly.
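The _effective_logits_scale substitution can be sketched as a standalone helper (hypothetical function; the PR applies the same scale at the three existing call-sites rather than through a wrapper like this):

```python
import torch

def temperature_scaled_logprobs(logits, labels, temperature=1.0,
                                logits_scale_factor=1.0):
    """Log-probs of the labels under a temperature-scaled softmax.

    Dividing the logits scale by the sampling temperature reproduces the
    actor's distribution, so new log-probs match the stored old log-probs
    at step 0 and the IS ratio starts at 1.0.
    """
    effective_scale = logits_scale_factor / temperature
    logp = torch.log_softmax(logits * effective_scale, dim=-1)
    return logp.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
```

With the default temperature=1.0 the scale is unchanged, which is why existing behaviour is preserved exactly.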
Summary
This PR adds a suite of RL training improvements to the GRPO/GSPO loss infrastructure, targeting the grpo-metrics branch. It contains four logical units:

1. GSPO loss (sequence-level IS-ratio clipping)

Implements GSPO as an alternative policy-gradient loss alongside the existing per-token GRPO clipping, controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo".
- fused_gspo_loss_forward_backward kernel: computes the per-segment geometric-mean log-ratio R_s, clips it at [1−ε_low, 1+ε_high], and applies R_s × A_s as a uniform per-token gradient within each segment. An all_reduce(SUM) over sequence-data-parallel ranks aggregates (lrn_sum, adv_sum, tok_count) before clipping, so the ratio is correct under sequence parallelism.
- document_index data field and LanguageModelKwargs.document_index kwarg constant route per-token segment membership through the data pipeline.
- tests/layers/test_gspo_loss.py: single-segment, packed sequences, ratio=1 equivalence, clipping, masking, SDP mock, gradient finite-diff, independence from per-token metrics.

2. Dynamic docs_per_step accumulation

Replaces static depth_first_micro_batches with a runtime document-count target, matching DeepSpeed's gradient_accumulation_passes semantics for RL (where each microbatch holds one rollout).
- ScheduleConfig.docs_per_step: when >0, Trainer._prefetch_to_doc_target fetches microbatches one at a time, all-reduces the per-microbatch document count, and stops once the global total ≥ docs_per_step. The final step total is broadcast to all inputs so the normalisation denominator is consistent.
- Trainer._get_or_build_schedule builds and caches a per-N Schedule with _depth_first_override = N // breadth_first_micro_batches, so the existing schedule machinery is reused without changes to the runner.
- Schedule._eff_{depth_first,sequential,num_inputs} properties expose the effective values for a given override.
- tests/layers/test_docs_per_step.py.

3. normalize_by_documents

Adds a normalize_by_documents flag to LanguageModelGRPOLossConfig. When True, both the GRPO and GSPO paths divide the loss by num_documents_in_batch (the step-level rollout count) rather than the token count. This matches DeepSpeed's normalization, where tokens_weights = 1 / gradient_accumulation_passes, making per-step gradient magnitudes comparable across the two implementations.

4. Temperature scaling for IS ratio parity

Adds a temperature field to LanguageModelGRPOLossConfig. When set to match the actor's sampling temperature (e.g. 0.7), new log-probabilities are computed at the same temperature as the stored old log-probabilities from vLLM, so the IS ratio starts near 1.0 at step 0 instead of ~1.08. Implementation: _effective_logits_scale = logits_scale_factor / temperature, substituted at all three call-sites in _forward_backward. Default temperature=1.0 preserves existing behaviour exactly.

Test plan
- pytest tests/layers/test_gspo_loss.py — GSPO unit tests
- pytest tests/layers/test_docs_per_step.py — docs_per_step unit tests
- pytest tests/layers/test_lm_losses.py — existing GRPO loss tests unaffected
- pytest tests/layers/test_grpo_metrics.py — metrics tests unaffected
- docs_per_step=1024, temperature=0.7, normalize_by_documents=true — verify grpo_kl_new_old ≈ 1e-4 at step 1, stable training through 400 steps