diff --git a/configs/prompts/tuneavideo.yaml b/configs/prompts/tuneavideo.yaml new file mode 100644 index 0000000..481da13 --- /dev/null +++ b/configs/prompts/tuneavideo.yaml @@ -0,0 +1,44 @@ +pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5" +output_dir: "output/output" + +train_data: + video_path: "output/dataset/video.mp4" + prompt: "buses passing by the park" + width: 512 + height: 512 + n_sample_frames: 8 + sample_start_idx: 0 + sample_frame_rate: 1 + +validation_data: + prompts: + - "validation_prompt_1" + - "validation_prompt_2" + num_inv_steps: 10 + use_inv_latent: true + +trainable_modules: + - "to_q" + +train_batch_size: 4 +max_train_steps: 1000 +learning_rate: 3e-5 +scale_lr: false +lr_scheduler: "constant" +lr_warmup_steps: 0 +adam_beta1: 0.9 +adam_beta2: 0.999 +adam_weight_decay: 1e-2 +adam_epsilon: 1e-08 +max_grad_norm: 1.0 +gradient_accumulation_steps: 1 +gradient_checkpointing: true +checkpointing_steps: 500 +resume_from_checkpoint: null +mixed_precision: "fp16" +use_8bit_adam: false +enable_xformers_memory_efficient_attention: true +seed: 1234 + +motion_module: "models/Motion_Module/mm_sd_v15.ckpt" +inference_config_path: "configs/inference/inference.yaml" diff --git a/train.py b/train.py index ee76d4f..279d4f6 100644 --- a/train.py +++ b/train.py @@ -1,387 +1,386 @@ -import argparse -import datetime -import logging -import inspect -import math -import os -from typing import Dict, Optional, Tuple -from omegaconf import OmegaConf -from collections import OrderedDict - -import torch -import torch.nn.functional as F -import torch.utils.checkpoint - -import diffusers -import transformers -from accelerate import Accelerator -from accelerate.logging import get_logger -from accelerate.utils import set_seed -from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler -from diffusers.optimization import get_scheduler -from diffusers.utils import check_min_version -from diffusers.utils.import_utils import is_xformers_available -from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTokenizer - -from animatediff.models.unet import UNet3DConditionModel -from tuneavideo.data.dataset import TuneAVideoDataset -from animatediff.pipelines.pipeline_animation import AnimationPipeline -from tuneavideo.util import save_videos_grid, ddim_inversion -from einops import rearrange - - -# Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
-check_min_version("0.10.0.dev0") - -logger = get_logger(__name__, log_level="INFO") - - -def main( - pretrained_model_path: str, - output_dir: str, - train_data: Dict, - validation_data: Dict, - validation_steps: int = 100, - trainable_modules: Tuple[str] = ( - "to_q", - ), - train_batch_size: int = 1, - max_train_steps: int = 500, - learning_rate: float = 3e-5, - scale_lr: bool = False, - lr_scheduler: str = "constant", - lr_warmup_steps: int = 0, - adam_beta1: float = 0.9, - adam_beta2: float = 0.999, - adam_weight_decay: float = 1e-2, - adam_epsilon: float = 1e-08, - max_grad_norm: float = 1.0, - gradient_accumulation_steps: int = 1, - gradient_checkpointing: bool = True, - checkpointing_steps: int = 500, - resume_from_checkpoint: Optional[str] = None, - mixed_precision: Optional[str] = "fp16", - use_8bit_adam: bool = False, - enable_xformers_memory_efficient_attention: bool = True, - seed: Optional[int] = None, - - motion_module: str = "models/Motion_Module/mm_sd_v15.ckpt", - inference_config_path: str = "configs/inference/inference.yaml", -): - *_, config = inspect.getargvalues(inspect.currentframe()) - - inference_config = OmegaConf.load(inference_config_path) - - accelerator = Accelerator( - gradient_accumulation_steps=gradient_accumulation_steps, - mixed_precision=mixed_precision, - ) - - # Make one log on every process with the configuration for debugging. - logging.basicConfig( - format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", - level=logging.INFO, - ) - logger.info(accelerator.state, main_process_only=False) - if accelerator.is_local_main_process: - transformers.utils.logging.set_verbosity_warning() - diffusers.utils.logging.set_verbosity_info() - else: - transformers.utils.logging.set_verbosity_error() - diffusers.utils.logging.set_verbosity_error() - - # If passed along, set the training seed now. - if seed is not None: - set_seed(seed) - - # Handle the output folder creation - if accelerator.is_main_process: - # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - # output_dir = os.path.join(output_dir, now) - os.makedirs(output_dir, exist_ok=True) - os.makedirs(f"{output_dir}/samples", exist_ok=True) - os.makedirs(f"{output_dir}/inv_latents", exist_ok=True) - OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) - - # Load scheduler, tokenizer and models. 
- noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") - tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") - text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") - vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") - unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) - - motion_module_state_dict = torch.load(motion_module, map_location="cpu") - if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) - missing, unexpected = unet.load_state_dict(motion_module_state_dict, strict=False) - assert len(unexpected) == 0 - - # Freeze vae and text_encoder - vae.requires_grad_(False) - text_encoder.requires_grad_(False) - - unet.requires_grad_(False) - for name, module in unet.named_modules(): - if "motion_modules" in name and name.endswith(tuple(trainable_modules)): - for params in module.parameters(): - params.requires_grad = True - - if enable_xformers_memory_efficient_attention: - if is_xformers_available(): - unet.enable_xformers_memory_efficient_attention() - else: - raise ValueError("xformers is not available. Make sure it is installed correctly") - - if gradient_checkpointing: - unet.enable_gradient_checkpointing() - - if scale_lr: - learning_rate = ( - learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes - ) - - # Initialize the optimizer - if use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" - ) - - optimizer_cls = bnb.optim.AdamW8bit - else: - optimizer_cls = torch.optim.AdamW - - optimizer = optimizer_cls( - unet.parameters(), - lr=learning_rate, - betas=(adam_beta1, adam_beta2), - weight_decay=adam_weight_decay, - eps=adam_epsilon, - ) - - # Get the training dataset - train_dataset = TuneAVideoDataset(**train_data) - - # Preprocessing the dataset - train_dataset.prompt_ids = tokenizer( - train_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" - ).input_ids[0] - - # DataLoaders creation: - train_dataloader = torch.utils.data.DataLoader( - train_dataset, batch_size=train_batch_size - ) - - # Get the validation pipeline - validation_pipeline = AnimationPipeline( - vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, - scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), - ) - validation_pipeline.enable_vae_slicing() - ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler') - ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps) - - # Scheduler - lr_scheduler = get_scheduler( - lr_scheduler, - optimizer=optimizer, - num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, - num_training_steps=max_train_steps * gradient_accumulation_steps, - ) - - # Prepare everything with our `accelerator`. - unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( - unet, optimizer, train_dataloader, lr_scheduler - ) - - # For mixed precision training we cast the text_encoder and vae weights to half-precision - # as these models are only used for inference, keeping weights in full precision is not required. 
- weight_dtype = torch.float32 - if accelerator.mixed_precision == "fp16": - weight_dtype = torch.float16 - elif accelerator.mixed_precision == "bf16": - weight_dtype = torch.bfloat16 - - # Move text_encode and vae to gpu and cast to weight_dtype - text_encoder.to(accelerator.device, dtype=weight_dtype) - vae.to(accelerator.device, dtype=weight_dtype) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) - # Afterwards we recalculate our number of training epochs - num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) - - # We need to initialize the trackers we use, and also store our configuration. - # The trackers initializes automatically on the main process. - if accelerator.is_main_process: - accelerator.init_trackers("text2video-fine-tune") - - # Train! - total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info(f" Num Epochs = {num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {max_train_steps}") - global_step = 0 - first_epoch = 0 - - # Potentially load in the weights and states from a previous save - if resume_from_checkpoint: - if resume_from_checkpoint != "latest": - path = os.path.basename(resume_from_checkpoint) - else: - # Get the most recent checkpoint - dirs = os.listdir(output_dir) - dirs = [d for d in dirs if d.startswith("checkpoint")] - dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) - path = dirs[-1] - accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(output_dir, path)) - global_step = int(path.split("-")[1]) - - first_epoch = global_step // num_update_steps_per_epoch - resume_step = global_step % num_update_steps_per_epoch - - # Only show the progress bar once on each machine. 
- progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process) - progress_bar.set_description("Steps") - - for epoch in range(first_epoch, num_train_epochs): - unet.train() - train_loss = 0.0 - for step, batch in enumerate(train_dataloader): - # Skip steps until we reach the resumed step - if resume_from_checkpoint and epoch == first_epoch and step < resume_step: - if step % gradient_accumulation_steps == 0: - progress_bar.update(1) - continue - - with accelerator.accumulate(unet): - # Convert videos to latent space - pixel_values = batch["pixel_values"].to(weight_dtype) - video_length = pixel_values.shape[1] - pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w") - latents = vae.encode(pixel_values).latent_dist.sample() - latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) - latents = latents * 0.18215 - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - # Sample a random timestep for each video - timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) - timesteps = timesteps.long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch["prompt_ids"])[0] - - # Get the target for loss depending on the prediction type - if noise_scheduler.prediction_type == "epsilon": - target = noise - elif noise_scheduler.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}") - - # Predict the noise residual and compute loss - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - - # Gather the losses across all processes for logging (if we use distributed training). 
- avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean() - train_loss += avg_loss.item() / gradient_accumulation_steps - - # Backpropagate - accelerator.backward(loss) - if accelerator.sync_gradients: - accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm) - optimizer.step() - lr_scheduler.step() - optimizer.zero_grad() - - # Checks if the accelerator has performed an optimization step behind the scenes - if accelerator.sync_gradients: - progress_bar.update(1) - global_step += 1 - accelerator.log({"train_loss": train_loss}, step=global_step) - train_loss = 0.0 - - if global_step % checkpointing_steps == 0: - if accelerator.is_main_process: - save_path = os.path.join(output_dir, f"mm-{global_step}.pth") - save_checkpoint(unet, save_path) - logger.info(f"Saved state to {save_path}") - - if global_step % validation_steps == 0: - if accelerator.is_main_process: - samples = [] - generator = torch.Generator(device=latents.device) - generator.manual_seed(seed) - - ddim_inv_latent = None - if validation_data.use_inv_latent: - inv_latents_path = os.path.join(output_dir, f"inv_latents/ddim_latent-{global_step}.pt") - ddim_inv_latent = ddim_inversion( - validation_pipeline, ddim_inv_scheduler, video_latent=latents, - num_inv_steps=validation_data.num_inv_steps, prompt="")[-1].to(weight_dtype) - torch.save(ddim_inv_latent, inv_latents_path) - - for idx, prompt in enumerate(validation_data.prompts): - sample = validation_pipeline(prompt, generator=generator, latents=ddim_inv_latent, - **validation_data).videos - save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif") - samples.append(sample) - samples = torch.concat(samples) - save_path = f"{output_dir}/samples/sample-{global_step}.gif" - save_videos_grid(samples, save_path) - logger.info(f"Saved samples to {save_path}") - - logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} - progress_bar.set_postfix(**logs) - - if global_step >= max_train_steps: - break - - # Create the pipeline using the trained modules and save it. 
- accelerator.wait_for_everyone() - if accelerator.is_main_process: - unet = accelerator.unwrap_model(unet) - pipeline = AnimationPipeline.from_pretrained( - pretrained_model_path, - text_encoder=text_encoder, - vae=vae, - unet=unet, - ) - - mm_path = "%s/mm.pth" % output_dir - save_checkpoint(unet, mm_path) - - accelerator.end_training() - -def save_checkpoint(unet, mm_path): - mm_state_dict = OrderedDict() - state_dict = unet.state_dict() - for key in state_dict: - if "motion_module" in key: - mm_state_dict[key] = state_dict[key] - - torch.save(mm_state_dict, mm_path) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml") - args = parser.parse_args() - - main(**OmegaConf.load(args.config)) +import argparse +import datetime +import logging +import inspect +import math +import os +from typing import Dict, Optional, Tuple +from omegaconf import OmegaConf +from collections import OrderedDict + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint + +import diffusers +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, DDIMScheduler +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version +from diffusers.utils.import_utils import is_xformers_available +from tqdm.auto import tqdm +from transformers import CLIPTextModel, CLIPTokenizer + +from animatediff.models.unet import UNet3DConditionModel +from tuneavideo.data.dataset import TuneAVideoDataset +from animatediff.pipelines.pipeline_animation import AnimationPipeline +from tuneavideo.util import save_videos_grid, ddim_inversion +from einops import rearrange + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.10.0.dev0") + +logger = get_logger(__name__, log_level="INFO") + + +def main( + pretrained_model_path: str, + output_dir: str, + train_data: Dict, + validation_data: Dict, + validation_steps: int = 100, + trainable_modules: Tuple[str] = ( + "to_q", + ), + train_batch_size: int = 1, + max_train_steps: int = 500, + learning_rate: float = 3e-5, + scale_lr: bool = False, + lr_scheduler: str = "constant", + lr_warmup_steps: int = 0, + adam_beta1: float = 0.9, + adam_beta2: float = 0.999, + adam_weight_decay: float = 1e-2, + adam_epsilon: float = 1e-08, + max_grad_norm: float = 1.0, + gradient_accumulation_steps: int = 1, + gradient_checkpointing: bool = True, + checkpointing_steps: int = 500, + resume_from_checkpoint: Optional[str] = None, + mixed_precision: Optional[str] = "fp16", + use_8bit_adam: bool = False, + enable_xformers_memory_efficient_attention: bool = True, + seed: Optional[int] = None, + + motion_module: str = "models/Motion_Module/mm_sd_v15.ckpt", + inference_config_path: str = "configs/inference/inference.yaml", +): + *_, config = inspect.getargvalues(inspect.currentframe()) + + inference_config = OmegaConf.load(inference_config_path) + + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + mixed_precision=mixed_precision, + ) + + # Make one log on every process with the configuration for debugging. 
+ logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if seed is not None: + set_seed(seed) + + # Handle the output folder creation + if accelerator.is_main_process: + # now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + # output_dir = os.path.join(output_dir, now) + os.makedirs(output_dir, exist_ok=True) + os.makedirs(f"{output_dir}/samples", exist_ok=True) + os.makedirs(f"{output_dir}/inv_latents", exist_ok=True) + OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) + + # Load scheduler, tokenizer and models. + noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") + text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") + vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") + unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)) + + motion_module_state_dict = torch.load(motion_module, map_location="cpu") + if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]}) + missing, unexpected = unet.load_state_dict(motion_module_state_dict, strict=False) + assert len(unexpected) == 0 + + # Freeze vae and text_encoder + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + + unet.requires_grad_(False) + for name, module in unet.named_modules(): + if "motion_modules" in name and name.endswith(tuple(trainable_modules)): + for params in module.parameters(): + params.requires_grad = True + + if enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if gradient_checkpointing: + unet.enable_gradient_checkpointing() + + if scale_lr: + learning_rate = ( + learning_rate * gradient_accumulation_steps * train_batch_size * accelerator.num_processes + ) + + # Initialize the optimizer + if use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + unet.parameters(), + lr=learning_rate, + betas=(adam_beta1, adam_beta2), + weight_decay=adam_weight_decay, + eps=adam_epsilon, + ) + + # Get the training dataset + train_dataset = TuneAVideoDataset(**train_data) + + # Preprocessing the dataset + train_dataset.prompt_ids = tokenizer( + train_dataset.prompt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" + ).input_ids[0] + + # DataLoaders creation: + train_dataloader = torch.utils.data.DataLoader( + train_dataset, batch_size=train_batch_size + ) + + # Get the validation pipeline + validation_pipeline = AnimationPipeline( + vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, + scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), + ) + validation_pipeline.enable_vae_slicing() + ddim_inv_scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder='scheduler') + ddim_inv_scheduler.set_timesteps(validation_data.num_inv_steps) + + # Scheduler + lr_scheduler = get_scheduler( + lr_scheduler, + optimizer=optimizer, + num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, + num_training_steps=max_train_steps * gradient_accumulation_steps, + ) + + # Prepare everything with our `accelerator`. + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move text_encode and vae to gpu and cast to weight_dtype + text_encoder.to(accelerator.device, dtype=weight_dtype) + vae.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) + # Afterwards we recalculate our number of training epochs + num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("text2video-fine-tune") + + # Train! + total_batch_size = train_batch_size * accelerator.num_processes * gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num Epochs = {num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {train_batch_size}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if resume_from_checkpoint: + if resume_from_checkpoint != "latest": + path = os.path.basename(resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(output_dir, path)) + global_step = int(path.split("-")[1]) + + first_epoch = global_step // num_update_steps_per_epoch + resume_step = global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, num_train_epochs): + unet.train() + train_loss = 0.0 + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % gradient_accumulation_steps == 0: + progress_bar.update(1) + continue + + with accelerator.accumulate(unet): + # Convert videos to latent space + pixel_values = batch["pixel_values"].to(weight_dtype) + video_length = pixel_values.shape[1] + pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w") + latents = vae.encode(pixel_values).latent_dist.sample() + latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length) + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each video + timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["prompt_ids"])[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.prediction_type == "epsilon": + target = noise + elif noise_scheduler.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.prediction_type}") + + # Predict the noise residual and compute loss + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Gather the losses across all processes for logging (if we use distributed training). 
+ avg_loss = accelerator.gather(loss.repeat(train_batch_size)).mean() + train_loss += avg_loss.item() / gradient_accumulation_steps + + # Backpropagate + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(unet.parameters(), max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + accelerator.log({"train_loss": train_loss}, step=global_step) + train_loss = 0.0 + + if global_step % checkpointing_steps == 0: + if accelerator.is_main_process: + save_path = os.path.join(output_dir, f"mm-{global_step}.pth") + save_checkpoint(unet, save_path) + logger.info(f"Saved state to {save_path}") + + if global_step % validation_steps == 0: + if accelerator.is_main_process: + samples = [] + generator = torch.Generator(device=latents.device) + generator.manual_seed(seed) + + ddim_inv_latent = None + if validation_data.use_inv_latent: + inv_latents_path = os.path.join(output_dir, f"inv_latents/ddim_latent-{global_step}.pt") + ddim_inv_latent = ddim_inversion( + validation_pipeline, ddim_inv_scheduler, video_latent=latents, + num_inv_steps=validation_data.num_inv_steps, prompt="")[-1].to(weight_dtype) + torch.save(ddim_inv_latent, inv_latents_path) + + for idx, prompt in enumerate(validation_data.prompts): + sample = validation_pipeline(prompt, generator=generator, latents=ddim_inv_latent, video_length=video_length, **validation_data).videos + save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif") + samples.append(sample) + samples = torch.concat(samples) + save_path = f"{output_dir}/samples/sample-{global_step}.gif" + save_videos_grid(samples, save_path) + logger.info(f"Saved samples to {save_path}") + + logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + + if global_step >= max_train_steps: + break + + # Create the pipeline using the trained modules and save it. 
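+    # Note: only the motion-module weights are persisted below (save_checkpoint() -> mm.pth);
+    # the assembled AnimationPipeline is constructed but not written to disk.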
+ accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = accelerator.unwrap_model(unet) + pipeline = AnimationPipeline.from_pretrained( + pretrained_model_path, + text_encoder=text_encoder, + vae=vae, + unet=unet, + ) + + mm_path = "%s/mm.pth" % output_dir + save_checkpoint(unet, mm_path) + + accelerator.end_training() + +def save_checkpoint(unet, mm_path): + mm_state_dict = OrderedDict() + state_dict = unet.state_dict() + for key in state_dict: + if "motion_module" in key: + mm_state_dict[key] = state_dict[key] + + torch.save(mm_state_dict, mm_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="./configs/tuneavideo.yaml") + args = parser.parse_args() + + main(**OmegaConf.load(args.config)) diff --git a/tuneavideo/data/dataset.py b/tuneavideo/data/dataset.py index 7c714aa..7ac4193 100644 --- a/tuneavideo/data/dataset.py +++ b/tuneavideo/data/dataset.py @@ -1,44 +1,45 @@ -import decord -decord.bridge.set_bridge('torch') - -from torch.utils.data import Dataset -from einops import rearrange - - -class TuneAVideoDataset(Dataset): - def __init__( - self, - video_path: str, - prompt: str, - width: int = 512, - height: int = 512, - n_sample_frames: int = 8, - sample_start_idx: int = 0, - sample_frame_rate: int = 1, - ): - self.video_path = video_path - self.prompt = prompt - self.prompt_ids = None - - self.width = width - self.height = height - self.n_sample_frames = n_sample_frames - self.sample_start_idx = sample_start_idx - self.sample_frame_rate = sample_frame_rate - - def __len__(self): - return 1 - - def __getitem__(self, index): - # load and sample video frames - vr = decord.VideoReader(self.video_path, width=self.width, height=self.height) - sample_index = list(range(self.sample_start_idx, len(vr), self.sample_frame_rate))[:self.n_sample_frames] - video = vr.get_batch(sample_index) - video = rearrange(video, "f h w c -> f c h w") - - example = { - "pixel_values": (video / 127.5 - 1.0), - "prompt_ids": self.prompt_ids - } - - return example +import decord +decord.bridge.set_bridge('torch') + +from torch.utils.data import Dataset +from einops import rearrange + + +class TuneAVideoDataset(Dataset): + def __init__( + self, + video_path: str, + prompt: str, + width: int = 512, + height: int = 512, + n_sample_frames: int = 8, + sample_start_idx: int = 0, + sample_frame_rate: int = 1, + ): + self.video_path = video_path + self.prompt = prompt + self.prompt_ids = None + + self.width = width + self.height = height + self.n_sample_frames = n_sample_frames + self.sample_start_idx = sample_start_idx + self.sample_frame_rate = sample_frame_rate + + def __len__(self): + return 1 + + def __getitem__(self, index): + # load and sample video frames + vr = decord.VideoReader(self.video_path, width=self.width, height=self.height) + video_length = len(vr) + sample_index = list(range(self.sample_start_idx, video_length, self.sample_frame_rate))[:self.n_sample_frames] + video = vr.get_batch(sample_index) + video = rearrange(video, "f h w c -> f c h w") + + example = { + "pixel_values": (video / 127.5 - 1.0), + "prompt_ids": self.prompt_ids, + } + + return example
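For reference, the checkpoint written by save_checkpoint() (mm.pth under output_dir, i.e. output/output/mm.pth with the config above) contains only the motion-module weights and can be loaded back the same way train.py loads the pretrained motion module. A minimal sketch, not part of the patch, using only the paths and calls that appear above:

import torch
from omegaconf import OmegaConf
from animatediff.models.unet import UNet3DConditionModel

# Rebuild the 3D UNet exactly as in main(), then overlay the fine-tuned motion module.
inference_config = OmegaConf.load("configs/inference/inference.yaml")
unet = UNet3DConditionModel.from_pretrained_2d(
    "models/StableDiffusion/stable-diffusion-v1-5",
    subfolder="unet",
    unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs),
)
mm_state_dict = torch.load("output/output/mm.pth", map_location="cpu")
missing, unexpected = unet.load_state_dict(mm_state_dict, strict=False)
assert len(unexpected) == 0  # only motion-module keys should be present

Training itself is launched as in the __main__ block, e.g. python train.py --config configs/prompts/tuneavideo.yaml (the argparse default points at ./configs/tuneavideo.yaml, so pass --config explicitly for the file added here).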