
gspo: GSPO loss, docs_per_step accumulation, normalize_by_documents, temperature scaling #502

Open
bigximik wants to merge 4 commits into grpo-metrics from gspo

Conversation

@bigximik
Collaborator

Summary

This PR adds a suite of RL training improvements to the GRPO/GSPO loss infrastructure, targeting the grpo-metrics branch. It contains four logical units:

1. GSPO loss (sequence-level IS-ratio clipping)

Implements GSPO as an alternative policy-gradient loss alongside the existing per-token GRPO clipping. Controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo".

  • New fused_gspo_loss_forward_backward kernel: computes per-segment geometric-mean log-ratio R_s, clips at [1−ε_low, 1+ε_high], and applies R_s × A_s as a uniform per-token gradient within each segment. An all_reduce(SUM) over sequence-data-parallel ranks aggregates (lrn_sum, adv_sum, tok_count) before clipping so the ratio is correct under sequence parallelism.
  • New document_index data field and LanguageModelKwargs.document_index kwarg constant to route per-token segment membership through the data pipeline.
  • 8 unit tests in tests/layers/test_gspo_loss.py (single-segment, packed sequences, ratio=1 equivalence, clipping, masking, SDP mock, gradient finite-diff, independence from per-token metrics).
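To make the segment-level mechanics concrete, here is a minimal pure-Python sketch of the GSPO ratio computation described above. The function name and return shape are illustrative, not the actual fused-kernel API; the real kernel operates on tensors and fuses the backward pass.

```python
import math
from collections import defaultdict

def gspo_segment_factors(new_logp, old_logp, advantages, doc_index,
                         eps_low=0.2, eps_high=0.2):
    """Illustrative sketch of GSPO sequence-level IS-ratio clipping.

    Accumulates per-segment sums (lrn_sum, adv_sum, tok_count), forms the
    geometric-mean ratio R_s = exp(mean log-ratio), clips it to
    [1 - eps_low, 1 + eps_high], and returns (clipped R_s, A_s) per segment.
    The fused kernel applies R_s * A_s uniformly to every token in segment s.
    """
    lrn_sum = defaultdict(float)   # sum of (new - old) log-probs per segment
    adv_sum = defaultdict(float)
    tok_count = defaultdict(int)
    for n, o, a, s in zip(new_logp, old_logp, advantages, doc_index):
        lrn_sum[s] += n - o
        adv_sum[s] += a
        tok_count[s] += 1
    # Under sequence parallelism, (lrn_sum, adv_sum, tok_count) would be
    # all-reduced (SUM) across SDP ranks here, before clipping.
    out = {}
    for s in lrn_sum:
        ratio = math.exp(lrn_sum[s] / tok_count[s])          # geometric mean
        clipped = min(max(ratio, 1 - eps_low), 1 + eps_high)
        out[s] = (clipped, adv_sum[s] / tok_count[s])
    return out
```

With new_logp equal to old_logp the ratio is exactly 1 for every segment, which is the "ratio=1 equivalence" case covered by the unit tests.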

2. Dynamic docs_per_step accumulation

Replaces static depth_first_micro_batches with a runtime document-count target — matching DeepSpeed's gradient_accumulation_passes semantics for RL (where each microbatch holds one rollout).

  • ScheduleConfig.docs_per_step: when >0, Trainer._prefetch_to_doc_target fetches microbatches one at a time, all-reduces the per-microbatch document count, and stops once the global total ≥ docs_per_step. The final step total is broadcast to all inputs so the normalization denominator is consistent.
  • Trainer._get_or_build_schedule builds and caches a per-N Schedule with _depth_first_override = N // breadth_first_micro_batches, so the existing schedule machinery is reused without changes to the runner.
  • Schedule._eff_{depth_first,sequential,num_inputs} properties expose the effective values for a given override.
  • 13 unit tests in tests/layers/test_docs_per_step.py.
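The accumulation loop can be sketched as follows. This is a simplified stand-in: fetch_microbatch and all_reduce_sum are illustrative hooks for the data iterator and the collective, and the dict key mirrors the num_documents_in_batch field named above.

```python
def prefetch_to_doc_target(fetch_microbatch, all_reduce_sum, docs_per_step):
    """Illustrative sketch of the docs_per_step accumulation loop.

    fetch_microbatch() -> (batch, local_doc_count); all_reduce_sum(x) sums x
    across data-parallel ranks. Fetches until the global document total
    reaches the target, then writes the final step total back onto every
    batch so the normalization denominator is identical everywhere.
    """
    batches, global_total = [], 0
    while global_total < docs_per_step:
        batch, local_docs = fetch_microbatch()
        global_total += all_reduce_sum(local_docs)
        batches.append(batch)
    for batch in batches:
        batch["num_documents_in_batch"] = global_total
    return batches, global_total
```

Note that the loop stops at the first microbatch that pushes the global total past the target, so the realized per-step document count can slightly exceed docs_per_step.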

3. normalize_by_documents

Adds a normalize_by_documents flag to LanguageModelGRPOLossConfig. When True, both the GRPO and GSPO paths divide the loss by num_documents_in_batch (the step-level rollout count) rather than the token count. This matches DeepSpeed's normalization where tokens_weights = 1 / gradient_accumulation_passes, making per-step gradient magnitudes comparable across the two implementations.
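The divisor choice reduces to a one-line switch; the helper below is illustrative (the real code selects the divisor inside the loss paths, not through a standalone function).

```python
def loss_divisor(num_labels_in_batch, num_documents_in_batch,
                 normalize_by_documents):
    """Illustrative sketch of the normalization switch.

    DeepSpeed's RL loop weights each rollout by
    1 / gradient_accumulation_passes; dividing by the document (rollout)
    count instead of the token count reproduces that gradient scale.
    """
    if normalize_by_documents:
        return num_documents_in_batch   # per-rollout normalization
    return num_labels_in_batch          # per-token normalization (default)
```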

4. Temperature scaling for IS ratio parity

Adds a temperature field to LanguageModelGRPOLossConfig. When set to match the actor's sampling temperature (e.g. 0.7), new log-probabilities are computed at the same temperature as the stored old log-probabilities from vLLM, so the IS ratio starts near 1.0 at step 0 instead of ~1.08. Implementation: _effective_logits_scale = logits_scale_factor / temperature, substituted at all three call-sites in _forward_backward. Default temperature=1.0 preserves existing behaviour exactly.
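A toy two-class example shows why the substitution matters: if the actor sampled at T=0.7 but the trainer scores at T=1.0, the IS ratio differs from 1 even when the policy is unchanged. The logit values below are illustrative.

```python
import math

def log_softmax(logits):
    # Numerically stable log-softmax over a small list.
    m = max(logits)
    z = math.log(sum(math.exp(l - m) for l in logits)) + m
    return [l - z for l in logits]

def effective_logits_scale(logits_scale_factor, temperature):
    # The substitution described above: divide the logits scale by T.
    return logits_scale_factor / temperature

logits = [2.0, 0.5]                                 # toy 2-class logits
old = log_softmax([l / 0.7 for l in logits])        # stored by the sampler at T=0.7
new_t1 = log_softmax(logits)                        # scored at T=1.0: mismatched
scale = effective_logits_scale(1.0, 0.7)
new_t07 = log_softmax([l * scale for l in logits])  # scored with the effective scale
ratio_mismatch = math.exp(new_t1[0] - old[0])       # != 1 despite identical policy
ratio_matched = math.exp(new_t07[0] - old[0])       # exactly 1
```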

Test plan

  • pytest tests/layers/test_gspo_loss.py — GSPO unit tests
  • pytest tests/layers/test_docs_per_step.py — docs_per_step unit tests
  • pytest tests/layers/test_lm_losses.py — existing GRPO loss tests unaffected
  • pytest tests/layers/test_grpo_metrics.py — metrics tests unaffected
  • End-to-end: 4-node Qwen2.5-7B math run with docs_per_step=1024, temperature=0.7, normalize_by_documents=true — verify grpo_kl_new_old ≈ 1e-4 at step 1, stable training through 400 steps

Implements GSPO (geometric-mean sequence-level policy-gradient loss) as
an alternative to the existing per-token GRPO clipping. Controlled via
LanguageModelGRPOLossConfig.policy_loss = "gspo".

Key changes:
- data pipeline: expose per-token document_index when return_document_index=True
- LanguageModelKwargs.document_index: new kwarg constant
- LanguageModelLoss: store SDP dim for cross-rank segment aggregation
- grpo.py: fused_gspo_loss_forward_backward with all_reduce(SUM) across
  SDP ranks before computing segment-level R_s and A_s; gradient derivation
  exploits tok_count cancellation so every token in a segment gets the
  same gradient factor R_s * clip_indicator_s
- tests/layers/test_gspo_loss.py: 8 unit tests (single-segment, packed,
  ratio-1 equivalence, clipping, masking, SDP mock, gradient finite-diff,
  per-token metrics unchanged)
Adds ScheduleConfig.rollouts_per_step (default 0). When >0, TrainerConfig._from_dict
computes depth_first_micro_batches = rollouts_per_step // (batch_data_parallel ×
breadth_first_micro_batches) before sub-configs are created (and frozen).

Matches DeepSpeed gradient_accumulation_passes semantics for RL: with train_batch_size=1
each microbatch holds one rollout, so setting rollouts_per_step=1024 with data_parallel=8
gives depth_first_micro_batches=128 → exactly 1024 rollouts per optimizer step globally.
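The division above can be sketched directly (function name is illustrative; the real computation happens inside TrainerConfig._from_dict):

```python
def derive_depth_first(rollouts_per_step, batch_data_parallel,
                       breadth_first_micro_batches=1):
    """Illustrative sketch: with one rollout per microbatch, the per-rank
    accumulation depth is the global rollout target divided by the number
    of microbatches processed in parallel per pass."""
    per_pass = batch_data_parallel * breadth_first_micro_batches
    assert rollouts_per_step % per_pass == 0, "target must divide evenly"
    return rollouts_per_step // per_pass
```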

YAML usage:
  schedule:
    rollouts_per_step: 1024   # replaces manual depth_first_micro_batches
  model:
    distributed:
      data_parallel: 8        # used for the division
- Rename rollouts_per_step → docs_per_step in ScheduleConfig; depth_first
  is now determined at runtime rather than statically in _from_dict
- Add Schedule._depth_first_override and _eff_{depth_first,sequential,num_inputs}
  properties so per-step schedules share the same config object as the runner
- Add Trainer._prefetch_to_doc_target: fetches microbatches one at a time,
  all-reduces doc count per microbatch, stops when global total ≥ docs_per_step,
  then resets num_documents_in_batch to the step total on all inputs
- Add Trainer._get_or_build_schedule: builds and caches per-N Schedule with
  _depth_first_override=N//breadth_first_micro_batches
- Add normalize_by_documents flag to LanguageModelGRPOLossConfig; when True
  both GRPO and GSPO paths divide by num_documents_in_batch instead of
  num_labels_in_batch (matches DeepSpeed's per-rollout normalization)
- Add tests/layers/test_docs_per_step.py: 13 unit tests covering divisor
  scaling, normalize_by_documents layer routing, Schedule._eff_* properties,
  and _prefetch_to_doc_target accumulation logic
Add temperature field to LanguageModelGRPOLossConfig. When set to match
the actor's sampling temperature (e.g. 0.7), new log-probs are computed
at the same temperature as the stored old log-probs, so the IS ratio
starts near 1.0 instead of ~1.08.

Implementation: _effective_logits_scale = logits_scale_factor / temperature,
substituted for logits_scale_factor at all three callsites in
_forward_backward (GRPO path, GSPO path, _register_pg_metrics). Default
temperature=1.0 preserves existing behaviour exactly.
@bigximik bigximik requested a review from jlamypoirier April 29, 2026 08:04
