support gemma4 #1304
Code Review
This pull request implements support for the Gemma-4 model, including its multimodal and MoE architectures. It introduces Gemma-4 specific inference structures, layer implementations, and weight loading mechanisms. Significant updates were made to the Triton attention kernels to support sliding window attention and larger head dimensions, and the MoE kernels were extended to support GELU activation and per-expert scaling. The review feedback suggests optimizing configuration flattening to avoid the overhead of AutoConfig, pre-allocating buffers for KV-shared layers to prevent memory fragmentation, and replacing Python-based RMSNorm implementations with Triton kernels for better performance. There is also a recommendation to avoid unnecessary memory copies during weight loading.
```python
# under text_config; flatten it so downstream code sees text-model fields
# at the top level (mirrors the gemma3 approach).
if "text_config" in self.config:
    hf_config = AutoConfig.from_pretrained(self.weight_dir_, trust_remote_code=True)
```
Using AutoConfig.from_pretrained inside _init_config can be quite heavy and may trigger unnecessary network checks or slow processing on every rank in a distributed environment. Since config.json is already loaded into self.config at line 129, you can flatten the configuration by directly accessing the text_config key if it exists. This avoids the overhead of the transformers AutoConfig class while achieving the same goal for downstream code.
Suggested change:
```diff
-        hf_config = AutoConfig.from_pretrained(self.weight_dir_, trust_remote_code=True)
+        if "text_config" in self.config:
+            self.config = self.config["text_config"]
```
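As a self-contained sketch of the suggested approach (the free-function form and the load_flattened_config name are hypothetical; in the PR this logic lives inside _init_config, which has already parsed config.json into self.config):

```python
import json
import os


def load_flattened_config(weight_dir: str) -> dict:
    # config.json is parsed locally, so no AutoConfig round-trip
    # (and no potential network check) is needed on any rank.
    with open(os.path.join(weight_dir, "config.json"), "r") as f:
        config = json.load(f)
    # Multimodal Gemma checkpoints nest the text-model fields under
    # "text_config"; surface them at the top level so downstream code
    # sees the same keys as a text-only checkpoint.
    if "text_config" in config:
        config = config["text_config"]
    return config
```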
```python
# K/V come from target layer's already-rotated, already-normed cache.
# Only rotate Q here. rotary_emb_fwd writes to k in place, so pass
# a 1-head throwaway tensor we can discard.
dummy_k = torch.empty((q.shape[0], 1, head_dim), dtype=q.dtype, device=q.device)
```
Allocating a dummy_k tensor using torch.empty on every call to _get_qkv for KV-shared layers can lead to memory fragmentation over time, especially during long generation sequences. Consider pre-allocating a sufficiently large buffer in the infer_state or reusing a shared workspace to avoid repeated allocations.
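One possible shape for that pre-allocation, sketched as a hypothetical helper (the get_dummy_k name and the dummy_k_buffer attribute on infer_state are not part of this PR):

```python
import torch


def get_dummy_k(infer_state, q: torch.Tensor, head_dim: int) -> torch.Tensor:
    # Reuse a cached throwaway buffer instead of calling torch.empty on
    # every _get_qkv call; rotary_emb_fwd only needs somewhere to write K.
    needed = q.shape[0]
    buf = getattr(infer_state, "dummy_k_buffer", None)
    if buf is None or buf.shape[0] < needed or buf.dtype != q.dtype:
        # Allocate once (or on rare growth), ideally sized up front to the
        # engine's max token count so steady-state decode never allocates.
        infer_state.dummy_k_buffer = torch.empty(
            (needed, 1, head_dim), dtype=q.dtype, device=q.device
        )
        buf = infer_state.dummy_k_buffer
    # Contents are discarded after rotary_emb_fwd writes into the view.
    return buf[:needed]
```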
```python
v_fp = v.float()
v_fp = v_fp * torch.rsqrt(v_fp.pow(2).mean(dim=-1, keepdim=True) + self.eps_)
v = v_fp.to(input.dtype)
```
The unweighted RMSNorm for the V tensor is implemented in eager PyTorch. While correct, computing v.pow(2).mean(...) followed by rsqrt and a final multiply through standard PyTorch ops materializes several temporary tensors. For long sequences in prefill mode this can cause memory spikes and a performance bottleneck. A dedicated Triton kernel for this normalization step would be more efficient, similar to how RMSNormWeight is handled; see the sketch below.
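A minimal sketch of such a kernel, assuming one Triton program per row with fp32 accumulation (the rmsnorm_noweight wrapper and kernel names are illustrative, not part of this PR):

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _rmsnorm_noweight_kernel(x_ptr, out_ptr, stride_row, n_cols, eps,
                             BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols
    x = tl.load(x_ptr + row * stride_row + cols, mask=mask, other=0.0).to(tl.float32)
    # Unweighted RMSNorm: x * rsqrt(mean(x^2) + eps), fused in one pass.
    rms = tl.rsqrt(tl.sum(x * x, axis=0) / n_cols + eps)
    y = x * rms
    tl.store(out_ptr + row * stride_row + cols,
             y.to(out_ptr.dtype.element_ty), mask=mask)


def rmsnorm_noweight(v: torch.Tensor, eps: float) -> torch.Tensor:
    # Flatten leading dims so each program normalizes one head_dim row.
    x = v.reshape(-1, v.shape[-1]).contiguous()
    out = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(x.shape[1])
    _rmsnorm_noweight_kernel[(x.shape[0],)](
        x, out, x.stride(0), x.shape[1], eps, BLOCK_SIZE=BLOCK_SIZE
    )
    return out.reshape(v.shape)
```

Unlike the eager version, this never writes float32 temporaries to global memory; the upcast, reduction, and scale all stay in registers.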
```python
gate_weight = gate_up_weight[expert_idx, start:end, :].contiguous()
up_weight = gate_up_weight[
    expert_idx, moe_intermediate_size + start : moe_intermediate_size + end, :
].contiguous()
```
Calling .contiguous() on slices of gate_up_weight creates additional copies of the expert weights in memory during the loading process. If the model has a large number of experts or a high intermediate dimension, this could significantly increase the peak memory usage of the loader. If self.quant_method.load_weight can handle non-contiguous tensors, these calls should be removed.
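If that holds, the copy-free version is just the views themselves; a sketch (assuming load_weight tolerates non-contiguous input, which is worth verifying first):

```python
# Plain slices are views into gate_up_weight, so peak loader memory stays
# bounded by the fused tensor itself instead of growing per expert.
gate_weight = gate_up_weight[expert_idx, start:end, :]
up_weight = gate_up_weight[
    expert_idx, moe_intermediate_size + start : moe_intermediate_size + end, :
]
```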