Support lightx2v_train#1059

Open
helloyongyang wants to merge 9 commits into ModelTC:main from helloyongyang:main
Conversation

@helloyongyang (Contributor)

No description provided.


@gemini-code-assist (Bot) left a comment

Code Review

This pull request introduces a comprehensive training framework for image and video generation models, supporting LoRA and full fine-tuning for LongCat and Qwen architectures. The implementation includes a registry-based system for models, datasets, and trainers, along with a flow-matching scheduler and checkpoint management. Key feedback includes addressing unreliable floating-point equality checks in the scheduler, fixing an off-by-one error in the checkpoint pruning logic, and optimizing performance by caching latent constants. Additionally, suggestions were made to allow configurable batch sizes in the data loader and to refactor the registry class for better inheritance practices.

```python
    )
    return DataLoader(
        dataset,
        batch_size=1,
```

high

The `batch_size` is hardcoded to 1 in the `DataLoader`. This will significantly limit training throughput and GPU utilization, even when using gradient accumulation. It should be configurable via `data_config_split` to allow for larger batches.

Suggested change:

```diff
-        batch_size=1,
+        batch_size=data_config_split.get("batch_size", 1),
```
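As a plain-Python sketch (assuming `data_config_split` is a dict-like per-split config, as in the suggestion), the fallback lookup behaves like this:

```python
# Hypothetical helper mirroring the suggested lookup: read "batch_size"
# from the split config, defaulting to 1 when it is absent.
def resolve_batch_size(data_config_split):
    return data_config_split.get("batch_size", 1)

# With an explicit batch size the configured value wins; otherwise the
# old hardcoded default of 1 is preserved.
configured = resolve_batch_size({"batch_size": 8})
default = resolve_batch_size({})
```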

```python
        sigmas = self._train_sigmas.to(device=self.device, dtype=dtype)
        schedule_timesteps = self._train_timesteps.to(self.device)
        timesteps = timesteps.to(self.device)
        sigma_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
```

high

Floating-point equality comparison `schedule_timesteps == t` is unreliable and may fail due to precision issues. Additionally, `nonzero().item()` will raise a `RuntimeError` if no exact match is found. Using `torch.argmin` with the absolute difference is more robust for finding the closest index in the schedule.

Suggested change:

```diff
-        sigma_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+        sigma_indices = [torch.argmin(torch.abs(schedule_timesteps - t)).item() for t in timesteps]
```
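The failure mode can be shown without torch (plain Python; the schedule values below are made up for illustration):

```python
# A linear timestep schedule; values are illustrative only.
schedule = [1000.0 * (1.0 - i / 4) for i in range(5)]  # 1000.0, 750.0, ..., 0.0

# A timestep perturbed by round-off, as can happen after dtype casts.
query = 750.0000001

# Exact equality finds no match for the perturbed value...
exact_matches = [i for i, t in enumerate(schedule) if t == query]

# ...while a nearest-index lookup (the argmin-of-absolute-difference idea)
# still recovers the intended index.
nearest = min(range(len(schedule)), key=lambda i: abs(schedule[i] - query))
```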

```python
        self.timesteps = (sigmas * self.num_train_timesteps).to(device)

    def step(self, model_output, timestep, sample, return_dict=True):
        step_index = (self.timesteps == timestep).nonzero()[0].item()
```

high

Similar to the issue in `get_sigmas`, using `==` for floating-point timesteps is unreliable. Use a more robust matching method like `torch.argmin` on the absolute difference.

Suggested change:

```diff
-        step_index = (self.timesteps == timestep).nonzero()[0].item()
+        step_index = torch.argmin(torch.abs(self.timesteps - timestep)).item()
```

Comment on lines +51 to +53
```python
        latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
        latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
        return (latent - latent_mean) * latent_std
```

medium

`latent_mean` and `latent_std` are reconstructed as new tensors on every call to `encode_to_latent`. This adds unnecessary overhead. These constants should be computed once and stored as instance attributes to avoid redundant allocations and transfers.

Suggested change:

```diff
-        latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
-        latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
-        return (latent - latent_mean) * latent_std
+        if not hasattr(self, "_latent_mean"):
+            self._latent_mean = torch.tensor(self.vae.config.latents_mean, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
+            self._latent_std = 1.0 / torch.tensor(self.vae.config.latents_std, device=self.device, dtype=self.dtype).view(1, 1, self.vae.config.z_dim, 1, 1)
+        return (latent - self._latent_mean) * self._latent_std
```
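The compute-once pattern can be sketched without torch (a hypothetical `LatentNormalizer` class; list arithmetic stands in for tensor ops):

```python
class LatentNormalizer:
    """Hypothetical sketch: build normalization constants lazily, once."""

    def __init__(self, latents_mean, latents_std):
        self._cfg_mean = latents_mean
        self._cfg_std = latents_std
        self._cache = None

    def _constants(self):
        # Constructed only on the first call; reused afterwards.
        if self._cache is None:
            mean = list(self._cfg_mean)
            inv_std = [1.0 / s for s in self._cfg_std]
            self._cache = (mean, inv_std)
        return self._cache

    def normalize(self, latent):
        mean, inv_std = self._constants()
        return [(x - m) * s for x, m, s in zip(latent, mean, inv_std)]
```

Each call after the first skips the constant construction entirely, which is the saving the review comment is after.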

```python
        if len(checkpoints) < total_limit:
            return

        for name in checkpoints[: len(checkpoints) - total_limit + 1]:
```

medium

There is an off-by-one error in the pruning logic. If `len(checkpoints)` is equal to `total_limit`, the current code will delete one checkpoint, leaving `total_limit - 1` checkpoints. The slice should be `[: len(checkpoints) - total_limit]` to correctly keep exactly `total_limit` checkpoints.

Suggested change:

```diff
-        for name in checkpoints[: len(checkpoints) - total_limit + 1]:
+        for name in checkpoints[: len(checkpoints) - total_limit]:
```
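A tiny list-based sketch (`prune` is a hypothetical helper, names invented) shows why the `+ 1` deletes one checkpoint too many:

```python
def prune(checkpoints, total_limit, off_by_one=False):
    """Return the checkpoints that survive pruning, oldest first."""
    if len(checkpoints) < total_limit:
        return list(checkpoints)
    extra = 1 if off_by_one else 0
    # off_by_one=True reproduces the original "+ 1" slice.
    doomed = checkpoints[: len(checkpoints) - total_limit + extra]
    return [c for c in checkpoints if c not in doomed]

# With exactly total_limit checkpoints on disk, the original slice still
# deletes one; the corrected slice keeps all total_limit of them.
buggy = prune(["ckpt-1", "ckpt-2", "ckpt-3"], total_limit=3, off_by_one=True)
fixed = prune(["ckpt-1", "ckpt-2", "ckpt-3"], total_limit=3)
```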

Comment on lines +1 to +53
```python
class Register(dict):
    def __init__(self, *args, **kwargs):
        super(Register, self).__init__(*args, **kwargs)
        self._dict = {}

    def __call__(self, target_or_name):
        if callable(target_or_name):
            return self.register(target_or_name)
        else:
            return lambda x: self.register(x, key=target_or_name)

    def register(self, target, key=None):
        if not callable(target):
            raise Exception(f"Error: {target} must be callable!")

        if key is None:
            key = target.__name__

        if key in self._dict:
            raise Exception(f"{key} already exists.")

        self[key] = target
        return target

    def __setitem__(self, key, value):
        self._dict[key] = value

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def __str__(self):
        return str(self._dict)

    def keys(self):
        return self._dict.keys()

    def values(self):
        return self._dict.values()

    def items(self):
        return self._dict.items()

    def get(self, key, default=None):
        return self._dict.get(key, default)

    def merge(self, other_register):
        for key, value in other_register.items():
            if key in self._dict:
                raise Exception(f"{key} already exists in target register.")
            self[key] = value
```

medium

The `Register` class implementation is redundant and potentially confusing. It inherits from `dict` but maintains an internal `self._dict` that it uses for most operations, while the base `dict` remains empty. It is cleaner to inherit from `dict` and use `self` directly. Additionally, use more specific exceptions like `ValueError` or `KeyError` instead of the generic `Exception`.

```python
class Register(dict):
    def __call__(self, target_or_name):
        if callable(target_or_name):
            return self.register(target_or_name)
        else:
            return lambda x: self.register(x, key=target_or_name)

    def register(self, target, key=None):
        if not callable(target):
            raise ValueError(f"Error: {target} must be callable!")

        if key is None:
            key = target.__name__

        if key in self:
            raise KeyError(f"{key} already exists.")

        self[key] = target
        return target

    def merge(self, other_register):
        for key, value in other_register.items():
            if key in self:
                raise KeyError(f"{key} already exists in target register.")
            self[key] = value
```
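For illustration, here is how the simplified registry would be used as a decorator, in both its bare and named forms (the registry and model names below are invented; the class body condenses the suggestion above):

```python
class Register(dict):
    # Condensed copy of the suggested dict-based registry.
    def __call__(self, target_or_name):
        if callable(target_or_name):
            return self.register(target_or_name)
        return lambda x: self.register(x, key=target_or_name)

    def register(self, target, key=None):
        if not callable(target):
            raise ValueError(f"Error: {target} must be callable!")
        if key is None:
            key = target.__name__
        if key in self:
            raise KeyError(f"{key} already exists.")
        self[key] = target
        return target


# Hypothetical model registry; names are for illustration only.
MODELS = Register()

@MODELS                    # registered under its own class name
class QwenImageModel:
    pass

@MODELS("longcat_video")   # registered under an explicit key
class LongCatVideoModel:
    pass
```

Because the class now stores entries in itself, `in`, `[]`, and `get` all work via plain `dict` semantics with no shadow `_dict` to keep in sync.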
