
support gemma4 #1304

Open

WANDY666 wants to merge 7 commits into main from support_gemma4


Conversation

@WANDY666
Contributor

No description provided.


@gemini-code-assist (Bot) left a comment


Code Review

This pull request implements support for the Gemma-4 model, including its multimodal and MoE architectures. It introduces Gemma-4-specific inference structures, layer implementations, and weight-loading mechanisms. The Triton attention kernels received significant updates to support sliding-window attention and larger head dimensions, and the MoE kernels were extended to support GELU activation and per-expert scaling. The review feedback suggests optimizing configuration flattening to avoid the overhead of AutoConfig, pre-allocating buffers for KV-shared layers to prevent memory fragmentation, and replacing Python-based RMSNorm implementations with Triton kernels for better performance. There is also a recommendation to avoid unnecessary memory copies during weight loading.

# under text_config; flatten it so downstream code sees text-model fields
# at the top level (mirrors the gemma3 approach).
if "text_config" in self.config:
    hf_config = AutoConfig.from_pretrained(self.weight_dir_, trust_remote_code=True)

medium

Using AutoConfig.from_pretrained inside _init_config can be quite heavy and may trigger unnecessary network checks or slow processing on every rank in a distributed environment. Since config.json is already loaded into self.config at line 129, you can flatten the configuration by directly accessing the text_config key if it exists. This avoids the overhead of the transformers AutoConfig class while achieving the same goal for downstream code.

Suggested change
-    hf_config = AutoConfig.from_pretrained(self.weight_dir_, trust_remote_code=True)
+    if "text_config" in self.config:
+        self.config = self.config["text_config"]

# K/V come from target layer's already-rotated, already-normed cache.
# Only rotate Q here. rotary_emb_fwd writes to k in place, so pass
# a 1-head throwaway tensor we can discard.
dummy_k = torch.empty((q.shape[0], 1, head_dim), dtype=q.dtype, device=q.device)

medium

Allocating a dummy_k tensor using torch.empty on every call to _get_qkv for KV-shared layers can lead to memory fragmentation over time, especially during long generation sequences. Consider pre-allocating a sufficiently large buffer in the infer_state or reusing a shared workspace to avoid repeated allocations.
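
A minimal sketch of the pre-allocation idea, assuming an infer_state object; the attribute and helper names (dummy_k_buffer, init_dummy_k_buffer, get_dummy_k) are illustrative, not the project's actual API:

# Hypothetical sketch: allocate one reusable throwaway K buffer up front instead
# of calling torch.empty in every _get_qkv call for KV-shared layers.
import torch

def init_dummy_k_buffer(infer_state, max_tokens, head_dim, dtype, device):
    # One head is enough: rotary_emb_fwd only needs a writable in-place target.
    infer_state.dummy_k_buffer = torch.empty(
        (max_tokens, 1, head_dim), dtype=dtype, device=device
    )

def get_dummy_k(infer_state, num_tokens):
    # Reuse a slice of the pre-allocated workspace; no new allocation per call.
    return infer_state.dummy_k_buffer[:num_tokens]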

Comment on lines +147 to +149
v_fp = v.float()
v_fp = v_fp * torch.rsqrt(v_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps_)
v = v_fp.to(input.dtype)

medium

The unweighted RMSNorm for the V tensor is implemented in Python. While correct, chaining v.pow(2).mean(...), rsqrt, and the final multiplication as separate eager PyTorch ops creates multiple temporary tensors. For long sequences in prefill mode, this can cause memory spikes and performance bottlenecks. It would be more efficient to use a dedicated Triton kernel for this normalization step, similar to how RMSNormWeight is handled.
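
A minimal sketch of a fused, unweighted RMSNorm in Triton; the kernel and wrapper names are illustrative and not the project's existing RMSNormWeight API:

# Hypothetical sketch: fuse the square/mean/rsqrt/scale steps into one kernel
# so no intermediate fp32 tensors are materialized.
import torch
import triton
import triton.language as tl

@triton.jit
def _rmsnorm_no_weight_kernel(x_ptr, out_ptr, stride_row, n_cols, eps, BLOCK_SIZE: tl.constexpr):
    # One program normalizes one row over its last dimension in fp32.
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(x_ptr + row * stride_row + cols, mask=mask, other=0.0).to(tl.float32)
    inv_rms = 1.0 / tl.sqrt(tl.sum(x * x, axis=0) / n_cols + eps)
    y = x * inv_rms
    tl.store(out_ptr + row * stride_row + cols, y.to(out_ptr.dtype.element_ty), mask=mask)

def rmsnorm_no_weight(v: torch.Tensor, eps: float) -> torch.Tensor:
    # Flatten leading dims to rows; normalize over the last dimension, as the
    # eager implementation above does with dim=-1.
    x = v.reshape(-1, v.shape[-1]).contiguous()
    out = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(x.shape[-1])
    _rmsnorm_no_weight_kernel[(x.shape[0],)](
        x, out, x.stride(0), x.shape[-1], eps, BLOCK_SIZE=BLOCK_SIZE
    )
    return out.reshape(v.shape)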

Comment on lines +21 to +24
gate_weight = gate_up_weight[expert_idx, start:end, :].contiguous()
up_weight = gate_up_weight[
    expert_idx, moe_intermediate_size + start : moe_intermediate_size + end, :
].contiguous()

medium

Calling .contiguous() on slices of gate_up_weight creates additional copies of the expert weights in memory during the loading process. If the model has a large number of experts or a high intermediate dimension, this could significantly increase the peak memory usage of the loader. If self.quant_method.load_weight can handle non-contiguous tensors, these calls should be removed.
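
A minimal sketch of the suggested change, assuming self.quant_method.load_weight accepts non-contiguous views (variable names follow the snippet above):

# Hypothetical sketch: keep the expert slices as views to avoid per-expert copies.
gate_weight = gate_up_weight[expert_idx, start:end, :]
up_weight = gate_up_weight[
    expert_idx, moe_intermediate_size + start : moe_intermediate_size + end, :
]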
