[Others]Prefill-Decode Batch Invariant in glm#8000

Open
bukejiyu wants to merge 1 commit into PaddlePaddle:develop from bukejiyu:pd_batch_invariant

@bukejiyu (Collaborator) commented Jun 4, 2026

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Add a monitoring unit test for GLM Prefill-Decode (PD) consistency.

Usage or Command

Accuracy Tests

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.


@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 Paddle-CI-Agent | pr_review | 2026-06-04 14:53:13

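📋 Review Summary

PR overview: Adds a Prefill-Decode batch consistency monitoring unit test for the GLM model, and also fixes numerical precision in the quantization kernel, the batch_invariant monkey patch for linear layers, and the warmup reset logic in deterministic mode.
Scope of changes: custom_ops/gpu_ops/, fastdeploy/model_executor/layers/, fastdeploy/worker/, fastdeploy/output/, tests/
Affected tags: [Models] [OP] [Quantization] [DataProcessor] [CI]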

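Issues

| Level | File | Summary |
| --- | --- | --- |
| 🟡 Suggestion | tests/model_loader/utils.py:137 | `form_model_get_output_logprobs` calls `fd_runner(llm_params)` instead of `fd_runner(**llm_params)`, so the function will always fail when called |
| 🟡 Suggestion | fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py:700 | The shape check `weight.shape[0] == K` in `linear_v2_batch_invariant` is ambiguous for square matrices (K == N) and may silently pick the wrong orientation |
| 🟡 Suggestion | fastdeploy/worker/gpu_worker.py:265 | The `elif isinstance(share_inputs, dict)` branch is dead code because `share_inputs` is always an `InputBatch` instance; the commented-out original call is also kept, which makes maintenance ambiguous |
| 🟡 Suggestion | tests/deterministic/test_glm_pd_consistency.py:9 | The test file uses `sys.path.insert` at the top level, which is an environment-dependent hack (checklist §C surface-level signal) |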

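📝 PR Convention Check

The PR title is [Others]Prefill-Decode Batch Invariant in glm. There is no space between the tag and the title text (adding one is recommended), and [Others] does not match the actual changes well: the diff touches the model layer, the quantization kernel, and tests, so a more precise tag combination is recommended. The Motivation section is empty (the template note was kept but nothing was filled in), and neither Usage or Command nor Accuracy Tests is filled in (they should say N/A). None of the checklist items are ticked.

Suggested title (can be copied directly):

  • [PD Disaggregation] Add GLM prefill-decode batch invariant consistency test

Suggested PR description (can be copied directly):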
## Motivation
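Add a monitoring unit test for Prefill-Decode output consistency (batch invariance) of the GLM model, so that logprobs consistency in the PD disaggregation scenario can be verified continuously by regression tests. Also address the numerical precision issue in the quantization kernel by switching from division to reciprocal multiplication, and add the missing monkey patches for the `linear`/`linear_v2` operators in `batch_invariant` mode.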

## Modifications
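- `custom_ops/gpu_ops/per_token_quant_fp8.cu`: replace division in the quantization kernel with `__frcp_rn` reciprocal multiplication to improve deterministic precision
- `fastdeploy/envs.py`: add the `FD_SKIP_IN_DETERMINISTIC` environment variable, which controls skipping the attention layer in unit tests
- `fastdeploy/model_executor/layers/attention/attention.py`: add a `skip_attn` parameter that skips the attention computation in deterministic mode
- `fastdeploy/model_executor/layers/batch_invariant_ops/batch_invariant_ops.py`: add batch-invariant replacements for `linear`/`linear_v2` and fix the inlining of `_compute_pid`
- `fastdeploy/model_executor/models/glm4_moe.py`: pass in the `skip_attn` parameter
- `fastdeploy/output/token_processor.py`: add the `prompt_token_ids` field to the output
- `fastdeploy/worker/gpu_worker.py`: reset share_inputs in a backward-compatible way after warmup in deterministic mode
- `tests/`: add the GLM PD consistency unit test and helper utility functions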

## Usage or Command
N/A

## Accuracy Tests
N/A

## Checklist

- [x] Add at least a tag in the PR title.
  - Tag list: [`[FDConfig]`,`[APIServer]`,`[Engine]`, `[Scheduler]`, `[PD Disaggregation]`, `[Executor]`, `[Graph Optimization]`, `[Speculative Decoding]`, `[RL]`, `[Models]`, `[Quantization]`, `[Loader]`, `[OP]`, `[KVCache]`, `[DataProcessor]`, `[BugFix]`, `[Docs]`, `[CI]`, `[Optimization]`, `[Feature]`, `[Benchmark]`, `[Others]`, `[XPU]`, `[HPU]`, `[GCU]`, `[DCU]`, `[Iluvatar]`, `[Metax]`]
  - You can add new tags based on the PR content, but the semantics must be clear.
- [ ] Format your code, run `pre-commit` before commit.
- [x] Add unit tests. Please write the reason in this PR if no unit tests.
- [ ] Provide accuracy results.
- [ ] If the current PR is submitting to the `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.

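Overall Assessment

The core logic of the PR is clear, the approach to PD consistency testing is reasonable, and the reciprocal optimization in the quantization kernel goes in the right direction. The main problems are the incorrect call in the test helper (form_model_get_output_logprobs can never run successfully) and the unnecessary dead-code branch introduced in gpu_worker.py. Merging is recommended once these are fixed.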
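As background for the reciprocal change mentioned above, the following is a minimal sketch of why swapping `x / scale` for `x * (1 / scale)` matters for bitwise comparisons: the two forms can round differently. It uses Python double-precision floats on the CPU purely for illustration; it is not the fp32 `__frcp_rn` path used in the CUDA kernel, and the variable names are assumptions.

```python
# Illustration only: double precision on CPU, not the fp32 GPU kernel.
x, scale = 49.0, 49.0

by_division = x / scale            # a single correctly rounded division
by_reciprocal = x * (1.0 / scale)  # two roundings: reciprocal, then multiply

print(by_division)                   # 1.0
print(by_reciprocal)                 # 0.9999999999999999
print(by_division == by_reciprocal)  # False
```

Because the results can differ in the last bit, every code path being compared bitwise has to use the same formulation.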

result_queue: result queue
"""
try:
with fd_runner(llm_params) as fd_model:

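🟡 Suggestion: fd_runner(llm_params) passes the whole dictionary as the first positional argument (model_name_or_path: str), so the call always fails or behaves incorrectly.

In the same file, form_model_pd_consistency_test at line 169 already uses the correct unpacking form fd_runner(**llm_params).

Suggested fix:

# Change
with fd_runner(llm_params) as fd_model:
# to
with fd_runner(**llm_params) as fd_model:

In addition, the exception handler references llm_params['model_path'], whereas the actual key should be model_name_or_path; this also needs to be corrected.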
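The difference between the two call forms can be reproduced in isolation. The sketch below uses a hypothetical stand-in for `fd_runner` (the real helper's signature and behavior may differ; only the `model_name_or_path` parameter name comes from the review) to show that the positional call binds the whole dict to the first parameter, while `**` unpacking matches keys to parameter names.

```python
from contextlib import contextmanager


@contextmanager
def fd_runner(model_name_or_path: str, tensor_parallel_size: int = 1):
    """Hypothetical stand-in for the test helper; not the real implementation."""
    if not isinstance(model_name_or_path, str):
        raise TypeError(
            f"model_name_or_path must be str, got {type(model_name_or_path).__name__}"
        )
    yield {"model": model_name_or_path, "tp": tensor_parallel_size}


# Example parameters (values are made up for illustration).
llm_params = {"model_name_or_path": "dummy-model", "tensor_parallel_size": 2}

# Positional call: the whole dict is bound to model_name_or_path.
try:
    with fd_runner(llm_params) as fd_model:
        positional_result = fd_model
except TypeError as exc:
    positional_result = f"failed: {exc}"

# Unpacked call: each key is matched to the parameter with the same name.
with fd_runner(**llm_params) as fd_model:
    unpacked_result = fd_model

print(positional_result)
print(unpacked_result)
```

In this stand-in the positional call fails loudly; a helper without a type check could instead carry the dict further before failing, which is harder to diagnose.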


def linear_v2_batch_invariant(x, weight, bias=None, weight_transposed=False):
"""Drop-in replacement for paddle._C_ops.linear_v2.


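🟡 Suggestion: the shape check weight.shape[0] == K is used to infer whether weight is already [K, N], but for a square matrix (K == N) the two layouts cannot be told apart. The code silently takes the "weight is already [K, N]" branch, which may produce an incorrect matrix multiplication.

It is recommended to decide based on the weight_transposed parameter directly (the caller already passes it in), or to align with the Paddle flag FLAGS_use_accuracy_compatible_kernel:

def linear_v2_batch_invariant(x, weight, bias=None, weight_transposed=False):
    # Decide based on the weight_transposed flag instead of guessing from shapes
    if weight_transposed:
        # weight is [K, N] (already transposed by the caller)
        return matmul_persistent(x, weight, bias)
    else:
        # weight is [N, K] and needs to be transposed
        return matmul_persistent(x, weight.T, bias)

Because the semantics of weight_transposed differ across Paddle versions, at least an assertion or a warning should be added for the square-matrix case.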
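The square-matrix ambiguity can be shown with a small pure-Python sketch. The helper names below (`linear_by_shape`, `linear_by_flag`) are hypothetical and only mimic the two strategies discussed; the flag is assumed here to mean "weight is already [K, N]", which may not match Paddle's actual semantics.

```python
def matmul(a, b):
    """Plain nested-list matrix product: a is [M, K], b is [K, N]."""
    return [[sum(p * q for p, q in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(col) for col in zip(*m)]


def linear_by_shape(x, weight):
    """Guess the layout from shapes, as the reviewed check does."""
    k = len(x[0])
    if len(weight) == k:  # looks like [K, N]; also true for any square weight
        return matmul(x, weight)
    return matmul(x, transpose(weight))


def linear_by_flag(x, weight, weight_transposed):
    """Use an explicit flag (assumed: True means weight is already [K, N])."""
    if weight_transposed:
        return matmul(x, weight)
    return matmul(x, transpose(weight))


x = [[1.0, 2.0]]                  # [M=1, K=2]
w_nk = [[1.0, 2.0], [3.0, 4.0]]   # square weight stored as [N, K]

guessed = linear_by_shape(x, w_nk)                           # treats w_nk as [K, N]
explicit = linear_by_flag(x, w_nk, weight_transposed=False)  # transposes first

print(guessed)   # [[7.0, 10.0]]
print(explicit)  # [[5.0, 11.0]]
```

For a non-square weight the shape guess and the flag agree, which is why the problem only surfaces when K == N.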

# Create a temporary InputBatch to borrow its logic
from fastdeploy.worker.input_batch import InputBatch

temp_batch = InputBatch.__new__(InputBatch)

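🟡 Suggestion: the line self.share_inputs = InputBatch(self.fd_config) at gpu_model_runner.py:226 confirms that share_inputs is always an InputBatch instance, so the elif isinstance(self.model_runner.share_inputs, dict) branch is dead code that never executes in production.

In addition, the commented-out original call # self.model_runner.share_inputs.reset_share_inputs() should be removed as well to avoid maintenance ambiguity.

Suggested simplification: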

if envs.FD_DETERMINISTIC_MODE:
    set_random_seed(self.fd_config.model_config.seed)
    if hasattr(self.model_runner.share_inputs, 'reset_share_inputs'):
        self.model_runner.share_inputs.reset_share_inputs()

#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software

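🟡 Suggestion: the test file injects a path with sys.path.insert at the top level, which is an environment-dependent hack (checklist §C surface-level signal). Other test files (for example those under tests/model_loader/) usually rely on the pythonpath configuration in conftest.py or pytest.ini to manage paths in one place.

It is recommended to configure pythonpath = tests in conftest.py or pyproject.toml / pytest.ini at the tests/ root instead, and to remove the sys.path manipulation from the test file.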

codecov-commenter commented Jun 4, 2026

Codecov Report

❌ Patch coverage is 39.53488% with 26 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@fab344e). Learn more about missing BASE report.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| .../layers/batch_invariant_ops/batch_invariant_ops.py | 42.85% | 14 Missing and 2 partials ⚠️ |
| fastdeploy/worker/gpu_worker.py | 9.09% | 9 Missing and 1 partial ⚠️ |
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #8000   +/-   ##
==========================================
  Coverage           ?   67.85%           
==========================================
  Files              ?      468           
  Lines              ?    65954           
  Branches           ?    10171           
==========================================
  Hits               ?    44756           
  Misses             ?    18359           
  Partials           ?     2839           
| Flag | Coverage Δ |
| --- | --- |
| GPU | 78.01% <39.53%> (?) |
| XPU | 7.02% <0.00%> (?) |


3 participants