Support lightx2v_train#1059
Code Review
This pull request introduces a comprehensive training framework for image and video generation models, supporting LoRA and full fine-tuning for LongCat and Qwen architectures. The implementation includes a registry-based system for models, datasets, and trainers, along with a flow-matching scheduler and checkpoint management. Key feedback includes addressing unreliable floating-point equality checks in the scheduler, fixing an off-by-one error in the checkpoint pruning logic, and optimizing performance by caching latent constants. Additionally, suggestions were made to allow configurable batch sizes in the data loader and to refactor the registry class for better inheritance practices.
```python
    )
    return DataLoader(
        dataset,
        batch_size=1,
```
The `batch_size` is hardcoded to 1 in the `DataLoader`. This will significantly limit training throughput and GPU utilization, even when using gradient accumulation. It should be configurable via `data_config_split` to allow for larger batches.
Suggested change:
```diff
-        batch_size=1,
+        batch_size=data_config_split.get("batch_size", 1),
```
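For context, a minimal sketch of how the loader construction could look once the batch size comes from the split config; the `build_dataloader` name and the extra `shuffle`/`num_workers` keys are assumptions, not part of this PR:

```python
from torch.utils.data import DataLoader

def build_dataloader(dataset, data_config_split):
    # Hypothetical builder: batch_size is read from the per-split config,
    # defaulting to 1 so existing configs keep today's behavior.
    return DataLoader(
        dataset,
        batch_size=data_config_split.get("batch_size", 1),
        shuffle=data_config_split.get("shuffle", True),
        num_workers=data_config_split.get("num_workers", 0),
    )
```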
```python
        sigmas = self._train_sigmas.to(device=self.device, dtype=dtype)
        schedule_timesteps = self._train_timesteps.to(self.device)
        timesteps = timesteps.to(self.device)
        sigma_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
```
The floating-point equality comparison `schedule_timesteps == t` is unreliable and may fail due to precision issues. Additionally, `nonzero().item()` will raise a `RuntimeError` if no exact match is found. Using `torch.argmin` on the absolute difference is more robust for finding the closest index in the schedule.
Suggested change:
```diff
-        sigma_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+        sigma_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
```
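A quick standalone illustration (not code from the PR) of why the exact-equality lookup can blow up: after a lower-precision round-trip the timestep no longer matches any schedule entry, so `.item()` on the empty `nonzero()` result raises, while the `argmin` form still recovers the nearest index:

```python
import torch

schedule_timesteps = torch.linspace(1000.0, 0.0, 50)
# A timestep that passed through a float16 cast no longer compares
# exactly equal to the schedule entry it came from.
t = schedule_timesteps[7].to(torch.float16).to(torch.float32)

print((schedule_timesteps == t).nonzero().numel())  # 0 -> .item() here would raise RuntimeError
print(torch.argmin(torch.abs(schedule_timesteps - t)).item())  # 7, the nearest index
```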
```python
        self.timesteps = (sigmas * self.num_train_timesteps).to(device)

    def step(self, model_output, timestep, sample, return_dict=True):
        step_index = (self.timesteps == timestep).nonzero()[0].item()
```
Similar to the issue in `get_sigmas`, using `==` for floating-point timesteps is unreliable. Use a more robust matching method, such as `torch.argmin` on the absolute difference.
Suggested change:
```diff
-        step_index = (self.timesteps == timestep).nonzero()[0].item()
+        step_index = torch.argmin(torch.abs(self.timesteps - timestep)).item()
```
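Since `get_sigmas` and `step` would now duplicate the same nearest-index lookup, it may be worth factoring it into a small shared helper; the `nearest_timestep_index` name is hypothetical:

```python
import torch

def nearest_timestep_index(schedule: torch.Tensor, t: torch.Tensor) -> int:
    # Nearest-match lookup: robust to precision drift and never raises
    # when no exact match exists in the schedule.
    return torch.argmin(torch.abs(schedule - t.to(schedule.device))).item()

# Possible call sites:
# step_index = nearest_timestep_index(self.timesteps, timestep)
# sigma_indices = [nearest_timestep_index(schedule_timesteps, t) for t in timesteps]
```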
```python
        latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
        latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
        return (latent - latent_mean) * latent_std
```
`latent_mean` and `latent_std` are reconstructed as new tensors on every call to `encode_to_latent`. This adds unnecessary overhead. These constants should be computed once and stored as instance attributes to avoid redundant allocations and transfers.
Suggested change:
```diff
-        latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
-        latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
-        return (latent - latent_mean) * latent_std
+        if not hasattr(self, "_latent_mean"):
+            self._latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
+            self._latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
+        return (latent - self._latent_mean) * self._latent_std
```
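The lazy `hasattr` guard above works; an alternative sketch is to build the constants in `__init__` so the hot path stays branch-free. The wrapper class and method names here are illustrative stand-ins, not the PR's actual classes:

```python
import torch

class VAEWrapper:  # illustrative stand-in for the PR's VAE wrapper
    def __init__(self, vae, device, dtype):
        self.vae = vae
        # Computed once at construction; every call below just reuses them.
        shape = (1, 1, vae.config.z_dim, 1, 1)
        self._latent_mean = torch.tensor(vae.config.latents_mean, device=device, dtype=dtype).view(shape)
        self._latent_std = 1.0 / torch.tensor(vae.config.latents_std, device=device, dtype=dtype).view(shape)

    def normalize_latent(self, latent):
        return (latent - self._latent_mean) * self._latent_std
```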
```python
        if len(checkpoints) < total_limit:
            return

        for name in checkpoints[: len(checkpoints) - total_limit + 1]:
```
There is an off-by-one error in the pruning logic. If `len(checkpoints)` is equal to `total_limit`, the current code will delete one checkpoint, leaving `total_limit - 1` checkpoints. The slice should be `[: len(checkpoints) - total_limit]` to correctly keep exactly `total_limit` checkpoints.
Suggested change:
```diff
-        for name in checkpoints[: len(checkpoints) - total_limit + 1]:
+        for name in checkpoints[: len(checkpoints) - total_limit]:
```
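With the fix applied, a self-contained sketch of the whole pruning routine (the directory layout and numeric-suffix sort key are assumptions about the checkpoint naming):

```python
import shutil
from pathlib import Path

def prune_checkpoints(output_dir, total_limit):
    # Oldest first, sorted by the numeric step suffix ("checkpoint-900"
    # before "checkpoint-1000"; a plain lexicographic sort would misorder them).
    checkpoints = sorted(
        (p for p in Path(output_dir).iterdir() if p.name.startswith("checkpoint-")),
        key=lambda p: int(p.name.split("-")[-1]),
    )
    if len(checkpoints) <= total_limit:
        return
    # Delete only the surplus: 5 checkpoints with total_limit=3 removes the 2 oldest.
    for path in checkpoints[: len(checkpoints) - total_limit]:
        shutil.rmtree(path)
```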
```python
class Register(dict):
    def __init__(self, *args, **kwargs):
        super(Register, self).__init__(*args, **kwargs)
        self._dict = {}

    def __call__(self, target_or_name):
        if callable(target_or_name):
            return self.register(target_or_name)
        else:
            return lambda x: self.register(x, key=target_or_name)

    def register(self, target, key=None):
        if not callable(target):
            raise Exception(f"Error: {target} must be callable!")

        if key is None:
            key = target.__name__

        if key in self._dict:
            raise Exception(f"{key} already exists.")

        self[key] = target
        return target

    def __setitem__(self, key, value):
        self._dict[key] = value

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def __str__(self):
        return str(self._dict)

    def keys(self):
        return self._dict.keys()

    def values(self):
        return self._dict.values()

    def items(self):
        return self._dict.items()

    def get(self, key, default=None):
        return self._dict.get(key, default)

    def merge(self, other_register):
        for key, value in other_register.items():
            if key in self._dict:
                raise Exception(f"{key} already exists in target register.")
            self[key] = value
```
The `Register` class implementation is redundant and potentially confusing. It inherits from `dict` but maintains an internal `self._dict` that it uses for most operations, while the base `dict` remains empty. It is cleaner to inherit from `dict` and use `self` directly. Additionally, use more specific exceptions like `ValueError` or `KeyError` instead of the generic `Exception`.
Suggested refactor:
```python
class Register(dict):
    def __call__(self, target_or_name):
        if callable(target_or_name):
            return self.register(target_or_name)
        else:
            return lambda x: self.register(x, key=target_or_name)

    def register(self, target, key=None):
        if not callable(target):
            raise ValueError(f"Error: {target} must be callable!")
        if key is None:
            key = target.__name__
        if key in self:
            raise KeyError(f"{key} already exists.")
        self[key] = target
        return target

    def merge(self, other_register):
        for key, value in other_register.items():
            if key in self:
                raise KeyError(f"{key} already exists in target register.")
            self[key] = value
```
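For reference, typical usage of such a registry as a decorator; the registry and model names below are illustrative, loosely based on the architectures named in the summary:

```python
MODEL_REGISTRY = Register()

@MODEL_REGISTRY                    # registered under the class name
class QwenImageModel:
    pass

@MODEL_REGISTRY("longcat_video")   # registered under an explicit key
class LongCatVideoModel:
    pass

model_cls = MODEL_REGISTRY["longcat_video"]
assert "QwenImageModel" in MODEL_REGISTRY
```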